Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
dbbd3ac8
Commit
dbbd3ac8
authored
Apr 08, 2025
by
Zhekai Zhang
Browse files
Support batch inference in flux model
parent
871f5272
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
29 deletions
+63
-29
nunchaku/models/transformers/transformer_flux.py
nunchaku/models/transformers/transformer_flux.py
+8
-8
src/FluxModel.cpp
src/FluxModel.cpp
+49
-16
src/Linear.cpp
src/Linear.cpp
+3
-2
src/kernels/awq/gemv_awq.cu
src/kernels/awq/gemv_awq.cu
+2
-2
src/kernels/misc_kernels.cu
src/kernels/misc_kernels.cu
+1
-1
No files found.
nunchaku/models/transformers/transformer_flux.py
View file @
dbbd3ac8
...
@@ -69,7 +69,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -69,7 +69,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
controlnet_single_block_samples
=
None
,
controlnet_single_block_samples
=
None
,
skip_first_layer
=
False
,
skip_first_layer
=
False
,
):
):
batch_size
=
hidden_states
.
shape
[
0
]
#
batch_size = hidden_states.shape[0]
txt_tokens
=
encoder_hidden_states
.
shape
[
1
]
txt_tokens
=
encoder_hidden_states
.
shape
[
1
]
img_tokens
=
hidden_states
.
shape
[
1
]
img_tokens
=
hidden_states
.
shape
[
1
]
...
@@ -95,9 +95,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -95,9 +95,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
assert
image_rotary_emb
.
ndim
==
6
assert
image_rotary_emb
.
ndim
==
6
assert
image_rotary_emb
.
shape
[
0
]
==
1
assert
image_rotary_emb
.
shape
[
0
]
==
1
assert
image_rotary_emb
.
shape
[
1
]
==
1
assert
image_rotary_emb
.
shape
[
1
]
==
1
assert
image_rotary_emb
.
shape
[
2
]
==
batch_size
*
(
txt_tokens
+
img_tokens
)
assert
image_rotary_emb
.
shape
[
2
]
==
1
*
(
txt_tokens
+
img_tokens
)
# [
bs
, tokens, head_dim / 2, 1, 2] (sincos)
# [
1
, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb
=
image_rotary_emb
.
reshape
([
batch_size
,
txt_tokens
+
img_tokens
,
*
image_rotary_emb
.
shape
[
3
:]])
image_rotary_emb
=
image_rotary_emb
.
reshape
([
1
,
txt_tokens
+
img_tokens
,
*
image_rotary_emb
.
shape
[
3
:]])
rotary_emb_txt
=
image_rotary_emb
[:,
:
txt_tokens
,
...]
# .to(self.dtype)
rotary_emb_txt
=
image_rotary_emb
[:,
:
txt_tokens
,
...]
# .to(self.dtype)
rotary_emb_img
=
image_rotary_emb
[:,
txt_tokens
:,
...]
# .to(self.dtype)
rotary_emb_img
=
image_rotary_emb
[:,
txt_tokens
:,
...]
# .to(self.dtype)
rotary_emb_single
=
image_rotary_emb
# .to(self.dtype)
rotary_emb_single
=
image_rotary_emb
# .to(self.dtype)
...
@@ -135,7 +135,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -135,7 +135,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
controlnet_block_samples
=
None
,
controlnet_block_samples
=
None
,
controlnet_single_block_samples
=
None
,
controlnet_single_block_samples
=
None
,
):
):
batch_size
=
hidden_states
.
shape
[
0
]
#
batch_size = hidden_states.shape[0]
txt_tokens
=
encoder_hidden_states
.
shape
[
1
]
txt_tokens
=
encoder_hidden_states
.
shape
[
1
]
img_tokens
=
hidden_states
.
shape
[
1
]
img_tokens
=
hidden_states
.
shape
[
1
]
...
@@ -155,9 +155,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -155,9 +155,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
assert
image_rotary_emb
.
ndim
==
6
assert
image_rotary_emb
.
ndim
==
6
assert
image_rotary_emb
.
shape
[
0
]
==
1
assert
image_rotary_emb
.
shape
[
0
]
==
1
assert
image_rotary_emb
.
shape
[
1
]
==
1
assert
image_rotary_emb
.
shape
[
1
]
==
1
assert
image_rotary_emb
.
shape
[
2
]
==
batch_size
*
(
txt_tokens
+
img_tokens
)
assert
image_rotary_emb
.
shape
[
2
]
==
1
*
(
txt_tokens
+
img_tokens
)
# [
bs
, tokens, head_dim / 2, 1, 2] (sincos)
# [
1
, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb
=
image_rotary_emb
.
reshape
([
batch_size
,
txt_tokens
+
img_tokens
,
*
image_rotary_emb
.
shape
[
3
:]])
image_rotary_emb
=
image_rotary_emb
.
reshape
([
1
,
txt_tokens
+
img_tokens
,
*
image_rotary_emb
.
shape
[
3
:]])
rotary_emb_txt
=
image_rotary_emb
[:,
:
txt_tokens
,
...]
# .to(self.dtype)
rotary_emb_txt
=
image_rotary_emb
[:,
:
txt_tokens
,
...]
# .to(self.dtype)
rotary_emb_img
=
image_rotary_emb
[:,
txt_tokens
:,
...]
# .to(self.dtype)
rotary_emb_img
=
image_rotary_emb
[:,
txt_tokens
:,
...]
# .to(self.dtype)
...
...
src/FluxModel.cpp
View file @
dbbd3ac8
...
@@ -60,7 +60,8 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
...
@@ -60,7 +60,8 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
Tensor
norm_x
=
norm
.
forward
(
x
);
Tensor
norm_x
=
norm
.
forward
(
x
);
debug
(
"norm_x"
,
norm_x
);
debug
(
"norm_x"
,
norm_x
);
kernels
::
mul_add
(
norm_x
,
scale_msa
,
shift_msa
);
// kernels::mul_add(norm_x, scale_msa, shift_msa);
kernels
::
mul_add_batch
(
norm_x
,
scale_msa
,
true
,
0.0
,
shift_msa
,
true
);
return
Output
{
norm_x
,
gate_msa
};
return
Output
{
norm_x
,
gate_msa
};
}
}
...
@@ -89,7 +90,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
...
@@ -89,7 +90,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
Tensor
norm_x
=
norm
.
forward
(
x
);
Tensor
norm_x
=
norm
.
forward
(
x
);
debug
(
"norm_x"
,
norm_x
);
debug
(
"norm_x"
,
norm_x
);
kernels
::
mul_add
(
norm_x
,
scale_msa
,
shift_msa
);
// kernels::mul_add(norm_x, scale_msa, shift_msa);
kernels
::
mul_add_batch
(
norm_x
,
scale_msa
,
true
,
0.0
,
shift_msa
,
true
);
debug
(
"norm_x_scaled"
,
norm_x
);
debug
(
"norm_x_scaled"
,
norm_x
);
return
Output
{
norm_x
};
return
Output
{
norm_x
};
...
@@ -100,7 +102,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
...
@@ -100,7 +102,8 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
Tensor
norm_x
=
norm
.
forward
(
x
);
Tensor
norm_x
=
norm
.
forward
(
x
);
debug
(
"norm_x"
,
norm_x
);
debug
(
"norm_x"
,
norm_x
);
kernels
::
mul_add
(
norm_x
,
scale_msa
,
shift_msa
);
// kernels::mul_add(norm_x, scale_msa, shift_msa);
kernels
::
mul_add_batch
(
norm_x
,
scale_msa
,
true
,
0.0
,
shift_msa
,
true
);
debug
(
"norm_x_scaled"
,
norm_x
);
debug
(
"norm_x_scaled"
,
norm_x
);
return
Output
{
norm_x
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
};
return
Output
{
norm_x
,
gate_msa
,
shift_mlp
,
scale_mlp
,
gate_mlp
};
...
@@ -335,7 +338,9 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
...
@@ -335,7 +338,9 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
// qkv_proj.forward(norm_hidden_states, qkv, {});
// qkv_proj.forward(norm_hidden_states, qkv, {});
// debug("qkv_raw", qkv);
// debug("qkv_raw", qkv);
qkv_proj
.
forward
(
norm_hidden_states
,
qkv
,
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
qkv_proj
.
forward
(
norm_hidden_states
.
slice
(
0
,
i
,
i
+
1
),
qkv
.
slice
(
0
,
i
,
i
+
1
),
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
);
}
debug
(
"qkv"
,
qkv
);
debug
(
"qkv"
,
qkv
);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
...
@@ -343,7 +348,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
...
@@ -343,7 +348,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
attn_output
=
attn
.
forward
(
qkv
);
attn_output
=
attn
.
forward
(
qkv
);
attn_output
=
attn_output
.
reshape
({
batch_size
,
num_tokens
,
num_heads
*
dim_head
});
attn_output
=
attn_output
.
reshape
({
batch_size
,
num_tokens
,
num_heads
*
dim_head
});
}
else
if
(
attnImpl
==
AttentionImpl
::
NunchakuFP16
)
{
}
else
if
(
attnImpl
==
AttentionImpl
::
NunchakuFP16
)
{
assert
(
batch_size
==
1
);
//
assert(batch_size == 1);
const
int
num_tokens_pad
=
ceilDiv
(
num_tokens
,
256
)
*
256
;
const
int
num_tokens_pad
=
ceilDiv
(
num_tokens
,
256
)
*
256
;
...
@@ -351,7 +356,14 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
...
@@ -351,7 +356,14 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
Tensor
k
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
Tensor
k
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
Tensor
v
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
Tensor
v
=
Tensor
::
allocate
({
batch_size
,
num_heads
,
num_tokens_pad
,
dim_head
},
Tensor
::
FP16
,
norm_hidden_states
.
device
());
qkv_proj
.
forward
(
norm_hidden_states
,
{},
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
,
q
,
k
,
v
,
num_tokens
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
qkv_proj
.
forward
(
norm_hidden_states
.
slice
(
0
,
i
,
i
+
1
),
{},
{},
norm_q
.
weight
,
norm_k
.
weight
,
rotary_emb
,
q
.
slice
(
0
,
i
,
i
+
1
),
k
.
slice
(
0
,
i
,
i
+
1
),
v
.
slice
(
0
,
i
,
i
+
1
),
num_tokens
);
}
debug
(
"packed_q"
,
q
);
debug
(
"packed_q"
,
q
);
debug
(
"packed_k"
,
k
);
debug
(
"packed_k"
,
k
);
...
@@ -361,7 +373,21 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
...
@@ -361,7 +373,21 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
kernels
::
attention_fp16
(
q
,
k
,
v
,
o
,
pow
(
dim_head
,
(
-
0.5
)));
kernels
::
attention_fp16
(
q
,
k
,
v
,
o
,
pow
(
dim_head
,
(
-
0.5
)));
attn_output
=
o
.
slice
(
1
,
0
,
num_tokens
);
if
(
batch_size
==
1
||
num_tokens_pad
==
num_tokens
)
{
attn_output
=
o
.
slice
(
1
,
0
,
num_tokens
);
}
else
{
attn_output
=
Tensor
::
allocate
({
batch_size
,
num_tokens
,
num_heads
*
dim_head
},
o
.
scalar_type
(),
o
.
device
());
checkCUDA
(
cudaMemcpy2DAsync
(
attn_output
.
data_ptr
(),
attn_output
.
stride
(
0
)
*
attn_output
.
scalar_size
(),
o
.
data_ptr
(),
o
.
stride
(
0
)
*
o
.
scalar_size
(),
attn_output
.
stride
(
0
)
*
attn_output
.
scalar_size
(),
batch_size
,
cudaMemcpyDeviceToDevice
,
getCurrentCUDAStream
()
));
}
}
else
{
}
else
{
assert
(
false
);
assert
(
false
);
}
}
...
@@ -379,7 +405,8 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
...
@@ -379,7 +405,8 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
hidden_states
=
kernels
::
add
(
attn_output
,
ff_output
);
hidden_states
=
kernels
::
add
(
attn_output
,
ff_output
);
debug
(
"attn_ff_output"
,
hidden_states
);
debug
(
"attn_ff_output"
,
hidden_states
);
kernels
::
mul_add
(
hidden_states
,
gate
,
residual
);
// kernels::mul_add(hidden_states, gate, residual);
kernels
::
mul_add_batch
(
hidden_states
,
gate
,
true
,
0.0
,
residual
,
true
);
nvtxRangePop
();
nvtxRangePop
();
...
@@ -627,7 +654,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -627,7 +654,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug
(
"img.attn_output"
,
attn_output
);
debug
(
"img.attn_output"
,
attn_output
);
#if 1
#if 1
kernels
::
mul_add
(
attn_output
,
gate_msa
,
hidden_states
);
// kernels::mul_add(attn_output, gate_msa, hidden_states);
kernels
::
mul_add_batch
(
attn_output
,
gate_msa
,
true
,
0.0
,
hidden_states
,
true
);
hidden_states
=
std
::
move
(
attn_output
);
hidden_states
=
std
::
move
(
attn_output
);
nvtxRangePop
();
nvtxRangePop
();
...
@@ -638,7 +666,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -638,7 +666,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor
norm_hidden_states
=
norm2
.
forward
(
hidden_states
);
Tensor
norm_hidden_states
=
norm2
.
forward
(
hidden_states
);
debug
(
"scale_mlp"
,
scale_mlp
);
debug
(
"scale_mlp"
,
scale_mlp
);
debug
(
"shift_mlp"
,
shift_mlp
);
debug
(
"shift_mlp"
,
shift_mlp
);
kernels
::
mul_add
(
norm_hidden_states
,
scale_mlp
,
shift_mlp
);
// kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels
::
mul_add_batch
(
norm_hidden_states
,
scale_mlp
,
true
,
0.0
,
shift_mlp
,
true
);
spdlog
::
debug
(
"norm_hidden_states={}"
,
norm_hidden_states
.
shape
.
str
());
spdlog
::
debug
(
"norm_hidden_states={}"
,
norm_hidden_states
.
shape
.
str
());
#else
#else
...
@@ -651,7 +680,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -651,7 +680,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug
(
"img.ff_output"
,
ff_output
);
debug
(
"img.ff_output"
,
ff_output
);
debug
(
"gate_mlp"
,
gate_mlp
);
debug
(
"gate_mlp"
,
gate_mlp
);
kernels
::
mul_add
(
ff_output
,
gate_mlp
,
hidden_states
);
// kernels::mul_add(ff_output, gate_mlp, hidden_states);
kernels
::
mul_add_batch
(
ff_output
,
gate_mlp
,
true
,
0.0
,
hidden_states
,
true
);
hidden_states
=
std
::
move
(
ff_output
);
hidden_states
=
std
::
move
(
ff_output
);
nvtxRangePop
();
nvtxRangePop
();
...
@@ -692,7 +722,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -692,7 +722,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug
(
"context.attn_output"
,
attn_output
);
debug
(
"context.attn_output"
,
attn_output
);
#if 1
#if 1
kernels
::
mul_add
(
attn_output
,
gate_msa
,
encoder_hidden_states
);
// kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
kernels
::
mul_add_batch
(
attn_output
,
gate_msa
,
true
,
0.0
,
encoder_hidden_states
,
true
);
encoder_hidden_states
=
std
::
move
(
attn_output
);
encoder_hidden_states
=
std
::
move
(
attn_output
);
nvtxRangePop
();
nvtxRangePop
();
...
@@ -703,7 +734,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -703,7 +734,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor
norm_hidden_states
=
norm2_context
.
forward
(
encoder_hidden_states
);
Tensor
norm_hidden_states
=
norm2_context
.
forward
(
encoder_hidden_states
);
debug
(
"c_scale_mlp"
,
scale_mlp
);
debug
(
"c_scale_mlp"
,
scale_mlp
);
debug
(
"c_shift_mlp"
,
shift_mlp
);
debug
(
"c_shift_mlp"
,
shift_mlp
);
kernels
::
mul_add
(
norm_hidden_states
,
scale_mlp
,
shift_mlp
);
// kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels
::
mul_add_batch
(
norm_hidden_states
,
scale_mlp
,
true
,
0.0
,
shift_mlp
,
true
);
spdlog
::
debug
(
"norm_hidden_states={}"
,
norm_hidden_states
.
shape
.
str
());
spdlog
::
debug
(
"norm_hidden_states={}"
,
norm_hidden_states
.
shape
.
str
());
#else
#else
...
@@ -718,7 +750,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
...
@@ -718,7 +750,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug
(
"context.ff_output"
,
ff_output
);
debug
(
"context.ff_output"
,
ff_output
);
debug
(
"c_gate_mlp"
,
gate_mlp
);
debug
(
"c_gate_mlp"
,
gate_mlp
);
kernels
::
mul_add
(
ff_output
,
gate_mlp
,
encoder_hidden_states
);
// kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
kernels
::
mul_add_batch
(
ff_output
,
gate_mlp
,
true
,
0.0
,
encoder_hidden_states
,
true
);
encoder_hidden_states
=
std
::
move
(
ff_output
);
encoder_hidden_states
=
std
::
move
(
ff_output
);
nvtxRangePop
();
nvtxRangePop
();
...
@@ -791,8 +824,8 @@ Tensor FluxModel::forward(
...
@@ -791,8 +824,8 @@ Tensor FluxModel::forward(
// txt first, same as diffusers
// txt first, same as diffusers
concat
=
Tensor
::
allocate
({
batch_size
,
txt_tokens
+
img_tokens
,
3072
},
dtype
,
device
);
concat
=
Tensor
::
allocate
({
batch_size
,
txt_tokens
+
img_tokens
,
3072
},
dtype
,
device
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
txt_tokens
).
copy_
(
encoder_hidden_states
);
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
0
,
txt_tokens
).
copy_
(
encoder_hidden_states
.
slice
(
0
,
i
,
i
+
1
)
);
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
hidden_states
);
concat
.
slice
(
0
,
i
,
i
+
1
).
slice
(
1
,
txt_tokens
,
txt_tokens
+
img_tokens
).
copy_
(
hidden_states
.
slice
(
0
,
i
,
i
+
1
)
);
}
}
hidden_states
=
concat
;
hidden_states
=
concat
;
encoder_hidden_states
=
{};
encoder_hidden_states
=
{};
...
...
src/Linear.cpp
View file @
dbbd3ac8
...
@@ -73,8 +73,9 @@ Tensor GEMV_AWQ::forward(Tensor x) {
...
@@ -73,8 +73,9 @@ Tensor GEMV_AWQ::forward(Tensor x) {
Tensor
out
=
gemv_awq
(
x
,
this
->
qweight
,
this
->
wscales
,
this
->
wzeros
,
M
,
out_features
,
in_features
,
group_size
);
Tensor
out
=
gemv_awq
(
x
,
this
->
qweight
,
this
->
wscales
,
this
->
wzeros
,
M
,
out_features
,
in_features
,
group_size
);
if
(
bias
.
valid
())
{
if
(
bias
.
valid
())
{
// TODO: batch
// TODO: batch
assert
(
out
.
numel
()
==
bias
.
numel
());
// assert(out.numel() == bias.numel());
out
=
kernels
::
add
(
out
,
bias
.
view
(
out
.
shape
.
dataExtent
));
// out = kernels::add(out, bias.view(out.shape.dataExtent));
kernels
::
mul_add_batch
(
out
,
{},
false
,
0.0
,
bias
,
false
);
}
}
debug
(
"out_before_lora"
,
out
);
debug
(
"out_before_lora"
,
out
);
...
...
src/kernels/awq/gemv_awq.cu
View file @
dbbd3ac8
...
@@ -303,10 +303,10 @@ Tensor gemv_awq(
...
@@ -303,10 +303,10 @@ Tensor gemv_awq(
constexpr
int
GROUP_SIZE
=
64
;
constexpr
int
GROUP_SIZE
=
64
;
assert
(
m
>
0
&&
m
<
8
);
assert
(
m
>
0
&&
m
<
=
8
);
assert
(
group_size
==
GROUP_SIZE
);
assert
(
group_size
==
GROUP_SIZE
);
dispatchVal
(
m
,
std
::
make_integer_sequence
<
int
,
8
>
(),
[
&
]
<
int
M
>
()
{
dispatchVal
(
m
,
std
::
make_integer_sequence
<
int
,
9
>
(),
[
&
]
<
int
M
>
()
{
if
constexpr
(
M
==
0
)
{
if
constexpr
(
M
==
0
)
{
assert
(
false
);
assert
(
false
);
return
;
return
;
...
...
src/kernels/misc_kernels.cu
View file @
dbbd3ac8
...
@@ -180,7 +180,7 @@ std::array<Tensor, N> split_mod(Tensor input) {
...
@@ -180,7 +180,7 @@ std::array<Tensor, N> split_mod(Tensor input) {
auto
stream
=
getCurrentCUDAStream
();
auto
stream
=
getCurrentCUDAStream
();
auto
shapeOut
=
input
.
shape
;
auto
shapeOut
=
TensorShape
(
input
.
shape
.
dataExtent
)
;
shapeOut
[
-
1
]
/=
N
;
shapeOut
[
-
1
]
/=
N
;
std
::
array
<
Tensor
,
N
>
out
;
std
::
array
<
Tensor
,
N
>
out
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment