Commit dbbd3ac8 authored by Zhekai Zhang

Support batch inference in flux model

parent 871f5272
......@@ -69,7 +69,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
controlnet_single_block_samples=None,
skip_first_layer=False,
):
batch_size = hidden_states.shape[0]
# batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
......@@ -95,9 +95,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
# [bs, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
# [1, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
rotary_emb_single = image_rotary_emb # .to(self.dtype)
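Note: the new asserts and the reshape above rely on the rotary table having no batch dimension. The (sin, cos) pairs depend only on token position, which is the same for every element of the batch, so a single copy with a leading dimension of 1 can be shared by all batch elements. A minimal reference of that sharing in plain C++, assuming a [tokens, head_dim/2, 2] sin/cos layout and adjacent-pair rotation (both layout details are assumptions for illustration, not taken from the actual kernels):

#include <cstddef>

// Apply one shared rotary table (no batch dimension) to a batched tensor
// laid out as [batch, tokens, heads, head_dim]. Illustrative sketch only.
void apply_rope_ref(float *qk, const float *sincos,
                    int batch, int tokens, int heads, int head_dim) {
    for (int b = 0; b < batch; b++)
        for (int t = 0; t < tokens; t++)
            for (int h = 0; h < heads; h++)
                for (int d = 0; d < head_dim / 2; d++) {
                    // the table is indexed by token position t only, never by b
                    size_t e = ((size_t)t * (head_dim / 2) + d) * 2;
                    float s = sincos[e + 0];
                    float c = sincos[e + 1];
                    size_t i = (((size_t)b * tokens + t) * heads + h) * head_dim + 2 * d;
                    float x0 = qk[i], x1 = qk[i + 1];
                    qk[i]     = x0 * c - x1 * s;
                    qk[i + 1] = x0 * s + x1 * c;
                }
}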
......@@ -135,7 +135,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
controlnet_block_samples=None,
controlnet_single_block_samples=None,
):
batch_size = hidden_states.shape[0]
# batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
......@@ -155,9 +155,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
# [bs, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
# [1, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
......
......@@ -60,7 +60,8 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
kernels::mul_add(norm_x, scale_msa, shift_msa);
// kernels::mul_add(norm_x, scale_msa, shift_msa);
kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
return Output{norm_x, gate_msa};
}
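Note: this is the first of many places in the commit where kernels::mul_add is swapped for kernels::mul_add_batch. The diff does not show what the extra arguments (the two booleans and the 0.0 constant) mean, so they are left out of the sketch below; what the rename suggests is that the scale-and-shift modulation, previously applied with a single scale/shift row, now accepts one row per batch element. A reference loop under that assumption (a sketch of the semantics, not the CUDA kernel):

#include <cstddef>

// Assumed semantics: x has shape [batch, tokens, channels]; scale and shift
// each hold one row of `channels` values per batch element and are broadcast
// over the token dimension.
void mul_add_batch_ref(float *x, const float *scale, const float *shift,
                       int batch, int tokens, int channels) {
    for (int b = 0; b < batch; b++)
        for (int t = 0; t < tokens; t++)
            for (int c = 0; c < channels; c++) {
                size_t i = ((size_t)b * tokens + t) * channels + c;
                x[i] = x[i] * scale[(size_t)b * channels + c]
                            + shift[(size_t)b * channels + c];
            }
}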
......@@ -89,7 +90,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
kernels::mul_add(norm_x, scale_msa, shift_msa);
// kernels::mul_add(norm_x, scale_msa, shift_msa);
kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
debug("norm_x_scaled", norm_x);
return Output{norm_x};
......@@ -100,7 +102,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
kernels::mul_add(norm_x, scale_msa, shift_msa);
// kernels::mul_add(norm_x, scale_msa, shift_msa);
kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
debug("norm_x_scaled", norm_x);
return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
......@@ -335,7 +338,9 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
// qkv_proj.forward(norm_hidden_states, qkv, {});
// debug("qkv_raw", qkv);
qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
for (int i = 0; i < batch_size; i++) {
qkv_proj.forward(norm_hidden_states.slice(0, i, i+1), qkv.slice(0, i, i+1), {}, norm_q.weight, norm_k.weight, rotary_emb);
}
debug("qkv", qkv);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
......@@ -343,7 +348,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
attn_output = attn.forward(qkv);
attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
assert(batch_size == 1);
// assert(batch_size == 1);
const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
......@@ -351,7 +356,14 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
qkv_proj.forward(norm_hidden_states, {}, {}, norm_q.weight, norm_k.weight, rotary_emb, q, k, v, num_tokens);
for (int i = 0; i < batch_size; i++) {
qkv_proj.forward(
norm_hidden_states.slice(0, i, i+1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb,
q.slice(0, i, i+1),
k.slice(0, i, i+1),
v.slice(0, i, i+1),
num_tokens);
}
debug("packed_q", q);
debug("packed_k", k);
......@@ -361,7 +373,21 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));
attn_output = o.slice(1, 0, num_tokens);
if (batch_size == 1 || num_tokens_pad == num_tokens) {
attn_output = o.slice(1, 0, num_tokens);
} else {
attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
checkCUDA(cudaMemcpy2DAsync(
attn_output.data_ptr(),
attn_output.stride(0) * attn_output.scalar_size(),
o.data_ptr(),
o.stride(0) * o.scalar_size(),
attn_output.stride(0) * attn_output.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
getCurrentCUDAStream()
));
}
} else {
assert(false);
}
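Note: when batch_size > 1 and the token dimension is padded (num_tokens_pad > num_tokens), slicing the padded attention output along dim 1 no longer yields one contiguous block, because each batch element's valid rows are followed by padding rows. The cudaMemcpy2DAsync above packs the valid region of every batch element in a single strided copy, with one "row" per batch element, the packed batch stride as the width, and the padded batch stride as the source pitch. The same effect spelled out as one plain copy per batch element, reusing the surrounding variables and the file's own helpers (illustrative only, not a proposed change):

// Equivalent per-element form of the cudaMemcpy2DAsync call: one 1D copy per
// batch element; the 2D variant fuses these by treating each element as a row.
for (int i = 0; i < batch_size; i++) {
    checkCUDA(cudaMemcpyAsync(
        (char *)attn_output.data_ptr() + (size_t)i * attn_output.stride(0) * attn_output.scalar_size(),
        (char *)o.data_ptr() + (size_t)i * o.stride(0) * o.scalar_size(),
        attn_output.stride(0) * attn_output.scalar_size(),  // bytes of the num_tokens valid rows
        cudaMemcpyDeviceToDevice,
        getCurrentCUDAStream()));
}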
......@@ -379,7 +405,8 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
hidden_states = kernels::add(attn_output, ff_output);
debug("attn_ff_output", hidden_states);
kernels::mul_add(hidden_states, gate, residual);
// kernels::mul_add(hidden_states, gate, residual);
kernels::mul_add_batch(hidden_states, gate, true, 0.0, residual, true);
nvtxRangePop();
......@@ -627,7 +654,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("img.attn_output", attn_output);
#if 1
kernels::mul_add(attn_output, gate_msa, hidden_states);
// kernels::mul_add(attn_output, gate_msa, hidden_states);
kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
hidden_states = std::move(attn_output);
nvtxRangePop();
......@@ -638,7 +666,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor norm_hidden_states = norm2.forward(hidden_states);
debug("scale_mlp", scale_mlp);
debug("shift_mlp", shift_mlp);
kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
// kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
......@@ -651,7 +680,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("img.ff_output", ff_output);
debug("gate_mlp", gate_mlp);
kernels::mul_add(ff_output, gate_mlp, hidden_states);
// kernels::mul_add(ff_output, gate_mlp, hidden_states);
kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
hidden_states = std::move(ff_output);
nvtxRangePop();
......@@ -692,7 +722,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("context.attn_output", attn_output);
#if 1
kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
// kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
encoder_hidden_states = std::move(attn_output);
nvtxRangePop();
......@@ -703,7 +734,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
debug("c_scale_mlp", scale_mlp);
debug("c_shift_mlp", shift_mlp);
kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
// kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
......@@ -718,7 +750,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("context.ff_output", ff_output);
debug("c_gate_mlp", gate_mlp);
kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
// kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
encoder_hidden_states = std::move(ff_output);
nvtxRangePop();
......@@ -791,8 +824,8 @@ Tensor FluxModel::forward(
// txt first, same as diffusers
concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
for (int i = 0; i < batch_size; i++) {
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states.slice(0, i, i + 1));
}
hidden_states = concat;
encoder_hidden_states = {};
......
......@@ -73,8 +73,9 @@ Tensor GEMV_AWQ::forward(Tensor x) {
Tensor out = gemv_awq(x, this->qweight, this->wscales, this->wzeros, M, out_features, in_features, group_size);
if (bias.valid()) {
// TODO: batch
assert(out.numel() == bias.numel());
out = kernels::add(out, bias.view(out.shape.dataExtent));
// assert(out.numel() == bias.numel());
// out = kernels::add(out, bias.view(out.shape.dataExtent));
kernels::mul_add_batch(out, {}, false, 0.0, bias, false);
}
debug("out_before_lora", out);
......
......@@ -303,10 +303,10 @@ Tensor gemv_awq(
constexpr int GROUP_SIZE = 64;
assert(m > 0 && m < 8);
assert(m > 0 && m <= 8);
assert(group_size == GROUP_SIZE);
dispatchVal(m, std::make_integer_sequence<int, 8>(), [&]<int M>() {
dispatchVal(m, std::make_integer_sequence<int, 9>(), [&]<int M>() {
if constexpr (M == 0) {
assert(false);
return;
......
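Note: std::make_integer_sequence<int, 8> expands to 0..7, so the old dispatch could never instantiate the M == 8 specialization; widening it to <int, 9> adds M == 8, which is why the assertion on m is relaxed from m < 8 to m <= 8, presumably so that the row counts produced by batched inputs are accepted. A minimal sketch of the compile-time dispatch pattern assumed here (the real dispatchVal may differ; the point is only how a runtime value is mapped onto a template parameter):

#include <cassert>
#include <utility>

// Try each compile-time candidate I from the sequence; invoke f<I>() on the
// one equal to the runtime value v.
template <typename F, int... Is>
void dispatch_val_sketch(int v, std::integer_sequence<int, Is...>, F &&f) {
    bool matched = (((v == Is) ? (f.template operator()<Is>(), true) : false) || ...);
    assert(matched && "value not covered by the integer_sequence");
}

// Usage mirroring the call site above (requires C++20 templated lambdas):
//   dispatch_val_sketch(m, std::make_integer_sequence<int, 9>(), [&]<int M>() {
//       /* code specialized for compile-time M */
//   });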
......@@ -180,7 +180,7 @@ std::array<Tensor, N> split_mod(Tensor input) {
auto stream = getCurrentCUDAStream();
auto shapeOut = input.shape;
auto shapeOut = TensorShape(input.shape.dataExtent);
shapeOut[-1] /= N;
std::array<Tensor, N> out;
......