fengzch-das / nunchaku · Commits

Commit 83b7542d
Authored Apr 03, 2025 by sxtyzhangzk
Committed Apr 04, 2025 by Zhekai Zhang

Fix resolution issue in flashattn2

Parent: bf3669dd
Showing 2 changed files with 35 additions and 2 deletions (+35, -2):

src/FluxModel.cpp  +34 -2
src/FluxModel.h    +1  -0
src/FluxModel.cpp (view file @ 83b7542d)
@@ -118,6 +118,33 @@ Attention::Attention(int num_heads, int dim_head, Device device) :
     headmask_type = headmask_type.copy(device);
 }
 
+Tensor Attention::forward(Tensor qkv) {
+    assert(qkv.ndims() == 3);
+
+    const Device device = qkv.device();
+    const int batch_size = qkv.shape[0];
+    const int num_tokens = qkv.shape[1];
+    assert(qkv.shape[2] == num_heads * dim_head * 3);
+
+    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
+    Tensor q = reshaped.slice(2, 0, num_heads);
+    Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
+    Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);
+
+    Tensor raw_attn_output = mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false).front();
+
+    assert(raw_attn_output.shape[0] == batch_size);
+    assert(raw_attn_output.shape[1] == num_tokens);
+    assert(raw_attn_output.shape[2] == num_heads);
+    assert(raw_attn_output.shape[3] == dim_head);
+
+    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
+}
+
 Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
     const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;
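For reference (not part of the commit): a minimal standalone sketch of the shape bookkeeping the new forward(Tensor qkv) overload performs. It does not use the repository's Tensor class, and the concrete B, N, H, D values are illustrative placeholders; only the view/slice arithmetic and the dim_head^(-1/2) softmax scale passed to mha_fwd mirror the code above.

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
        const int B = 1;      // batch size (illustrative)
        const int N = 4096;   // number of tokens (illustrative)
        const int H = 24;     // num_heads (illustrative)
        const int D = 128;    // dim_head (illustrative)

        // The fused projection packs Q, K and V back to back in the last dim: 3 * H * D.
        const long long packed_dim = 3LL * H * D;

        // qkv.view({B, N, H * 3, D}) regroups that dim into (3H, D) without copying,
        // so slice(2, 0, H) is Q, slice(2, H, 2H) is K and slice(2, 2H, 3H) is V.
        assert(packed_dim == static_cast<long long>(H * 3) * D);

        // Softmax scale passed to mha_fwd: pow(dim_head, -0.5).
        std::printf("softmax scale for D=%d: %f\n", D, std::pow(static_cast<double>(D), -0.5));

        // mha_fwd returns [B, N, H, D]; the overload views it as [B * N, H, D], and
        // FluxSingleTransformerBlock later reshapes it to [B, N, H * D]. All three
        // layouts hold the same number of elements, so the views are legal:
        const long long elems = 1LL * B * N * H * D;
        assert(elems == 1LL * (B * N) * H * D);
        assert(elems == 1LL * B * N * (H * D));
        return 0;
    }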
@@ -312,7 +339,8 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
         debug("qkv", qkv);
         // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
-        attn_output = attn.forward(qkv, {}, 0);
+        // attn_output = attn.forward(qkv, {}, 0);
+        attn_output = attn.forward(qkv);
         attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
     } else if (attnImpl == AttentionImpl::NunchakuFP16) {
         assert(batch_size == 1);
@@ -501,7 +529,11 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
         nvtxRangePushA("Attention");
-        raw_attn_output = attn.forward(concat, pool, sparsityRatio);
+        if (pool.valid()) {
+            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
+        } else {
+            raw_attn_output = attn.forward(concat);
+        }
         nvtxRangePop();
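Again for reference only: a standalone sketch of the dispatch this hunk introduces. TinyTensor and do_attention are stand-ins for the repository's Tensor and Attention::forward; only the pool.valid() control flow mirrors the change, everything else is placeholder.

    #include <cstdio>
    #include <vector>

    struct TinyTensor {
        std::vector<float> data;
        bool valid() const { return !data.empty(); }  // empty tensor means "no pooled QKV supplied"
    };

    // Stand-ins for the two Attention::forward overloads.
    static const char *do_attention(const TinyTensor &) { return "plain FlashAttention-2 path"; }
    static const char *do_attention(const TinyTensor &, const TinyTensor &, float) {
        return "pooled / sparse attention path";
    }

    int main() {
        TinyTensor concat{{1.f, 2.f, 3.f}};
        TinyTensor pool{};            // no pooled QKV available
        float sparsityRatio = 0.5f;

        // Same guard the commit adds: only take the pooled path when pool is valid.
        const char *path = pool.valid() ? do_attention(concat, pool, sparsityRatio)
                                        : do_attention(concat);
        std::printf("%s\n", path);
        return 0;
    }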
src/FluxModel.h (view file @ 83b7542d)
@@ -63,6 +63,7 @@ public:
     static constexpr int POOL_SIZE = 128;
 
     Attention(int num_heads, int dim_head, Device device);
+    Tensor forward(Tensor qkv);
     Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);
 
     static void setForceFP16(Module *module, bool value);