Commit dbbd3ac8 authored by Zhekai Zhang

Support batch inference in flux model

parent 871f5272
@@ -69,7 +69,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         controlnet_single_block_samples=None,
         skip_first_layer=False,
     ):
-        batch_size = hidden_states.shape[0]
+        # batch_size = hidden_states.shape[0]
         txt_tokens = encoder_hidden_states.shape[1]
         img_tokens = hidden_states.shape[1]
@@ -95,9 +95,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         assert image_rotary_emb.ndim == 6
         assert image_rotary_emb.shape[0] == 1
         assert image_rotary_emb.shape[1] == 1
-        assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
-        # [bs, tokens, head_dim / 2, 1, 2] (sincos)
-        image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
+        assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
+        # [1, tokens, head_dim / 2, 1, 2] (sincos)
+        image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
         rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]  # .to(self.dtype)
         rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)
         rotary_emb_single = image_rotary_emb  # .to(self.dtype)
@@ -135,7 +135,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
     ):
-        batch_size = hidden_states.shape[0]
+        # batch_size = hidden_states.shape[0]
         txt_tokens = encoder_hidden_states.shape[1]
         img_tokens = hidden_states.shape[1]
@@ -155,9 +155,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         assert image_rotary_emb.ndim == 6
         assert image_rotary_emb.shape[0] == 1
         assert image_rotary_emb.shape[1] == 1
-        assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
-        # [bs, tokens, head_dim / 2, 1, 2] (sincos)
-        image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
+        assert image_rotary_emb.shape[2] == 1 * (txt_tokens + img_tokens)
+        # [1, tokens, head_dim / 2, 1, 2] (sincos)
+        image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
         rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]  # .to(self.dtype)
         rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)
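The assertions and reshape above pin the rotary table to batch size 1: the (sin, cos) table depends only on token position, so a single copy serves every sample in the batch. A minimal C++ sketch of what that broadcast means for a consuming kernel (types and names here are illustrative, not the project's API):

```cpp
// Fetch the rotary (sin, cos) pair for one token position; the batch index is
// deliberately unused, which is exactly what the [1, tokens, ...] shape buys.
struct SinCos {
    float sin_v, cos_v;
};

SinCos lookup_rotary(const SinCos *table, int /*batch*/, int token, int pair,
                     int pairs_per_token) {
    return table[token * pairs_per_token + pair];  // no batch stride: broadcast
}
```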
@@ -60,7 +60,8 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
     Tensor norm_x = norm.forward(x);
     debug("norm_x", norm_x);
-    kernels::mul_add(norm_x, scale_msa, shift_msa);
+    // kernels::mul_add(norm_x, scale_msa, shift_msa);
+    kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
     return Output{norm_x, gate_msa};
 }
@@ -89,7 +90,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
     Tensor norm_x = norm.forward(x);
     debug("norm_x", norm_x);
-    kernels::mul_add(norm_x, scale_msa, shift_msa);
+    // kernels::mul_add(norm_x, scale_msa, shift_msa);
+    kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
     debug("norm_x_scaled", norm_x);
     return Output{norm_x};
@@ -100,7 +102,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
     Tensor norm_x = norm.forward(x);
     debug("norm_x", norm_x);
-    kernels::mul_add(norm_x, scale_msa, shift_msa);
+    // kernels::mul_add(norm_x, scale_msa, shift_msa);
+    kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
     debug("norm_x_scaled", norm_x);
     return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
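Every `kernels::mul_add(x, scale, shift)` call in these blocks is swapped for `kernels::mul_add_batch`, since the AdaLayerNorm scale/shift come from the per-sample conditioning embedding and must now be applied per batch element. The kernel itself is not shown in this diff, so the following is only a hedged CPU reference of one plausible reading, where the flag after each operand says whether it carries a batch dimension and the scalar is a constant folded into the scale (all names are illustrative):

```cpp
#include <cstddef>
#include <vector>

// x:     [batch, tokens, dim], updated in place: x = x * (scale + c) + shift
// scale: [batch, dim] if scale_batched, else [dim]; empty means "no scaling"
// shift: [batch, dim] if shift_batched, else [dim]
void mul_add_batch_ref(std::vector<float> &x, size_t batch, size_t tokens,
                       size_t dim, const std::vector<float> &scale,
                       bool scale_batched, float c,
                       const std::vector<float> &shift, bool shift_batched) {
    for (size_t b = 0; b < batch; b++)
        for (size_t t = 0; t < tokens; t++)
            for (size_t d = 0; d < dim; d++) {
                size_t i = (b * tokens + t) * dim + d;
                float s = scale.empty()
                              ? 1.0f
                              : scale[(scale_batched ? b * dim : 0) + d] + c;
                float sh = shift[(shift_batched ? b * dim : 0) + d];
                x[i] = x[i] * s + sh;
            }
}
```

Whether the conventional `1 + scale` offset of AdaLayerNorm lives inside the kernel or in the scalar argument is not visible in this diff; the sketch just exposes both knobs.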
@@ -335,7 +338,9 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     // qkv_proj.forward(norm_hidden_states, qkv, {});
     // debug("qkv_raw", qkv);
-    qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
+    for (int i = 0; i < batch_size; i++) {
+        qkv_proj.forward(norm_hidden_states.slice(0, i, i+1), qkv.slice(0, i, i+1), {}, norm_q.weight, norm_k.weight, rotary_emb);
+    }
     debug("qkv", qkv);
     // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
@@ -343,7 +348,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
         attn_output = attn.forward(qkv);
         attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
     } else if (attnImpl == AttentionImpl::NunchakuFP16) {
-        assert(batch_size == 1);
+        // assert(batch_size == 1);
         const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
@@ -351,7 +356,14 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
         Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
         Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
-        qkv_proj.forward(norm_hidden_states, {}, {}, norm_q.weight, norm_k.weight, rotary_emb, q, k, v, num_tokens);
+        for (int i = 0; i < batch_size; i++) {
+            qkv_proj.forward(
+                norm_hidden_states.slice(0, i, i+1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb,
+                q.slice(0, i, i+1),
+                k.slice(0, i, i+1),
+                v.slice(0, i, i+1),
+                num_tokens);
+        }
         debug("packed_q", q);
         debug("packed_k", k);
@@ -361,7 +373,21 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
         kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));
-        attn_output = o.slice(1, 0, num_tokens);
+        if (batch_size == 1 || num_tokens_pad == num_tokens) {
+            attn_output = o.slice(1, 0, num_tokens);
+        } else {
+            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
+            checkCUDA(cudaMemcpy2DAsync(
+                attn_output.data_ptr(),
+                attn_output.stride(0) * attn_output.scalar_size(),
+                o.data_ptr(),
+                o.stride(0) * o.scalar_size(),
+                attn_output.stride(0) * attn_output.scalar_size(),
+                batch_size,
+                cudaMemcpyDeviceToDevice,
+                getCurrentCUDAStream()));
+        }
     } else {
         assert(false);
     }
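With more than one sample and real padding, a plain `slice` can no longer yield the packed `[batch, num_tokens, hidden]` layout, because each sample's valid rows sit at the start of a longer padded block. `cudaMemcpy2DAsync` handles exactly this: it copies `height` rows of `width` bytes whose starting addresses advance by `spitch` in the source and `dpitch` in the destination. A minimal standalone sketch of the same pattern (function and parameter names are illustrative):

```cpp
#include <cuda_runtime.h>
#include <cassert>
#include <cstddef>

// Each of `batch` samples occupies `paddedBytes` in `src`, but only its first
// `validBytes` are real attention output; one 2D copy packs those prefixes
// tightly into `dst`.
void strip_padding(void *dst, const void *src, size_t validBytes,
                   size_t paddedBytes, int batch, cudaStream_t stream) {
    // Here a "row" is one sample: width = validBytes, spitch = paddedBytes,
    // dpitch = validBytes (destination rows are packed back to back).
    cudaError_t err = cudaMemcpy2DAsync(dst, validBytes, src, paddedBytes,
                                        validBytes, batch,
                                        cudaMemcpyDeviceToDevice, stream);
    assert(err == cudaSuccess);
    (void)err;
}
```

For example, with batch_size = 2 and num_tokens = 1000 padded to num_tokens_pad = 1024 at 3072 FP16 channels, spitch is 1024 * 3072 * 2 = 6,291,456 bytes while width and dpitch are 1000 * 3072 * 2 = 6,144,000 bytes.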
@@ -379,7 +405,8 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     hidden_states = kernels::add(attn_output, ff_output);
     debug("attn_ff_output", hidden_states);
-    kernels::mul_add(hidden_states, gate, residual);
+    // kernels::mul_add(hidden_states, gate, residual);
+    kernels::mul_add_batch(hidden_states, gate, true, 0.0, residual, true);
     nvtxRangePop();
@@ -627,7 +654,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     debug("img.attn_output", attn_output);
 #if 1
-    kernels::mul_add(attn_output, gate_msa, hidden_states);
+    // kernels::mul_add(attn_output, gate_msa, hidden_states);
+    kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
     hidden_states = std::move(attn_output);
     nvtxRangePop();
@@ -638,7 +666,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     Tensor norm_hidden_states = norm2.forward(hidden_states);
     debug("scale_mlp", scale_mlp);
     debug("shift_mlp", shift_mlp);
-    kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
+    // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
+    kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
     spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
 #else
@@ -651,7 +680,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     debug("img.ff_output", ff_output);
     debug("gate_mlp", gate_mlp);
-    kernels::mul_add(ff_output, gate_mlp, hidden_states);
+    // kernels::mul_add(ff_output, gate_mlp, hidden_states);
+    kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
     hidden_states = std::move(ff_output);
     nvtxRangePop();
@@ -692,7 +722,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     debug("context.attn_output", attn_output);
 #if 1
-    kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
+    // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
+    kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
     encoder_hidden_states = std::move(attn_output);
     nvtxRangePop();
@@ -703,7 +734,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
     debug("c_scale_mlp", scale_mlp);
     debug("c_shift_mlp", shift_mlp);
-    kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
+    // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
+    kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
     spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
 #else
@@ -718,7 +750,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     debug("context.ff_output", ff_output);
     debug("c_gate_mlp", gate_mlp);
-    kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
+    // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
+    kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
     encoder_hidden_states = std::move(ff_output);
     nvtxRangePop();
@@ -791,8 +824,8 @@ Tensor FluxModel::forward(
         // txt first, same as diffusers
         concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
         for (int i = 0; i < batch_size; i++) {
-            concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
-            concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
+            concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
+            concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states.slice(0, i, i + 1));
         }
         hidden_states = concat;
         encoder_hidden_states = {};
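Each sample's token stream is its text tokens followed by its image tokens; previously every slot `i` was filled from the whole batched source tensors, which is only shape-correct for batch size 1. A plain-array reference of the intended layout (a sketch; the real code uses the project's `Tensor::slice` + `copy_`):

```cpp
#include <cstddef>
#include <cstring>

// dst: [batch, txt_tokens + img_tokens, dim], txt-first per sample, matching
// the diffusers ordering noted in the comment above.
void concat_tokens(float *dst, const float *txt, const float *img,
                   int batch, int txt_tokens, int img_tokens, int dim) {
    for (int b = 0; b < batch; b++) {
        float *d = dst + (size_t)b * (txt_tokens + img_tokens) * dim;
        // sample b's text tokens, then its image tokens
        std::memcpy(d, txt + (size_t)b * txt_tokens * dim,
                    sizeof(float) * (size_t)txt_tokens * dim);
        std::memcpy(d + (size_t)txt_tokens * dim,
                    img + (size_t)b * img_tokens * dim,
                    sizeof(float) * (size_t)img_tokens * dim);
    }
}
```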
@@ -73,8 +73,9 @@ Tensor GEMV_AWQ::forward(Tensor x) {
     Tensor out = gemv_awq(x, this->qweight, this->wscales, this->wzeros, M, out_features, in_features, group_size);
     if (bias.valid()) {
         // TODO: batch
-        assert(out.numel() == bias.numel());
-        out = kernels::add(out, bias.view(out.shape.dataExtent));
+        // assert(out.numel() == bias.numel());
+        // out = kernels::add(out, bias.view(out.shape.dataExtent));
+        kernels::mul_add_batch(out, {}, false, 0.0, bias, false);
     }
     debug("out_before_lora", out);
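Read against the CPU reference sketched after the AdaLayerNorm hunks, an empty scale with both flags false degenerates to a broadcast bias add, `out[b][r][j] += bias[j]`: one bias vector shared by every sample and row. The replaced path asserted `out.numel() == bias.numel()`, which cannot hold once `out` carries a batch dimension.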
@@ -303,10 +303,10 @@ Tensor gemv_awq(
     constexpr int GROUP_SIZE = 64;
-    assert(m > 0 && m < 8);
+    assert(m > 0 && m <= 8);
     assert(group_size == GROUP_SIZE);
-    dispatchVal(m, std::make_integer_sequence<int, 8>(), [&]<int M>() {
+    dispatchVal(m, std::make_integer_sequence<int, 9>(), [&]<int M>() {
         if constexpr (M == 0) {
             assert(false);
             return;
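`std::make_integer_sequence<int, N>` generates 0 through N-1, so the previous `<int, 8>` sequence could instantiate at most M == 7; with batching, m can now reach 8, which is why both the assert bound and the sequence length move up by one. For readers unfamiliar with the idiom, here is a hedged sketch of how such a dispatcher usually maps a runtime integer onto a compile-time template parameter (a guess at its shape, not the project's actual definition):

```cpp
#include <cassert>
#include <utility>

template <typename F, int... Is>
void dispatchVal(int v, std::integer_sequence<int, Is...>, F &&f) {
    // For the Is matching v, call f.template operator()<Is>(); the comma
    // operator turns the void call into `true` inside the fold expression.
    bool matched = ((v == Is && (f.template operator()<Is>(), true)) || ...);
    assert(matched && "dispatchVal: value out of dispatch range");
    (void)matched;
}

// Usage mirroring the call site (C++20 templated lambda):
//   dispatchVal(m, std::make_integer_sequence<int, 9>(), [&]<int M>() { ... });
```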
@@ -180,7 +180,7 @@ std::array<Tensor, N> split_mod(Tensor input) {
     auto stream = getCurrentCUDAStream();
-    auto shapeOut = input.shape;
+    auto shapeOut = TensorShape(input.shape.dataExtent);
     shapeOut[-1] /= N;
     std::array<Tensor, N> out;