Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
11c6e423
Commit
11c6e423
authored
Sep 03, 2025
by
PanZezhong1725
Browse files
fix jiuge_awq qk_buf
parent
5330d5fa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
src/models/jiuge_awq/jiuge_awq.cpp
src/models/jiuge_awq/jiuge_awq.cpp
+3
-3
No files found.
src/models/jiuge_awq/jiuge_awq.cpp
View file @
11c6e423
...
@@ -118,7 +118,7 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc,
...
@@ -118,7 +118,7 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc,
max_seq_len
=
std
::
max
(
max_seq_len
,
size_t
(
seq_len
));
max_seq_len
=
std
::
max
(
max_seq_len
,
size_t
(
seq_len
));
}
}
auto
qk_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nh
,
max_qk_size
},
rsrc
.
memory_pool
);
auto
qk_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nh
*
max_qk_size
},
rsrc
.
memory_pool
);
auto
rearrange_q_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
rearrange_q_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
q_rearrange
=
rearrange_q_buf
->
view
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
auto
q_rearrange
=
rearrange_q_buf
->
view
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
auto
attn_val_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
attn_val_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
...
@@ -158,11 +158,11 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc,
...
@@ -158,11 +158,11 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc,
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
// qk
// qk
rearrange
(
q_rearrange
->
slice
(
2
,
0
,
seq_len
),
q
);
rearrange
(
q_rearrange
->
slice
(
2
,
0
,
seq_len
),
q
);
auto
qk_gemm
=
qk_buf
->
slice
(
1
,
0
,
seq_len
*
total_len
)
->
view
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
auto
qk_gemm
=
qk_buf
->
slice
(
0
,
0
,
nh
*
seq_len
*
total_len
)
->
view
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
auto
k_gemm
=
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
});
auto
k_gemm
=
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
});
linear
(
qk_gemm
,
rearrange_q_buf
->
slice
(
1
,
0
,
ngroup
*
seq_len
),
k_gemm
,
1.
f
/
float
(
sqrt
(
dh
)),
0.
f
,
nullptr
,
nullptr
);
linear
(
qk_gemm
,
rearrange_q_buf
->
slice
(
1
,
0
,
ngroup
*
seq_len
),
k_gemm
,
1.
f
/
float
(
sqrt
(
dh
)),
0.
f
,
nullptr
,
nullptr
);
// softmax
// softmax
auto
qk_softmax
=
qk_
buf
->
slice
(
1
,
0
,
seq_len
*
total_len
)
->
view
({
nh
,
seq_len
,
total_len
});
auto
qk_softmax
=
qk_
gemm
->
view
({
nh
,
seq_len
,
total_len
});
causalSoftmax
(
qk_softmax
,
qk_softmax
);
causalSoftmax
(
qk_softmax
,
qk_softmax
);
auto
v_gemm
=
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
});
auto
v_gemm
=
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
});
linear
(
attn_val_buf
->
slice
(
1
,
0
,
ngroup
*
seq_len
),
qk_gemm
,
v_gemm
,
1.
f
,
0.
f
,
nullptr
,
nullptr
);
linear
(
attn_val_buf
->
slice
(
1
,
0
,
ngroup
*
seq_len
),
qk_gemm
,
v_gemm
,
1.
f
,
0.
f
,
nullptr
,
nullptr
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment