Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
c3d5efa5
Commit
c3d5efa5
authored
May 21, 2025
by
PanZezhong
Browse files
support qkv bias
parent
5540d53a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
2 deletions
+17
-2
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+17
-2
No files found.
src/models/jiuge/jiuge.cpp
View file @
c3d5efa5
...
@@ -77,10 +77,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -77,10 +77,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
dh
=
meta
.
dh
;
auto
dh
=
meta
.
dh
;
auto
d
=
meta
.
d
;
auto
d
=
meta
.
d
;
auto
dt_logits
=
meta
.
dt_logits
;
auto
dt_logits
=
meta
.
dt_logits
;
// std::cout << "dt_logits: " <<(int)dt_logits << std::endl;
auto
di
=
meta
.
di
/
ndev
;
auto
di
=
meta
.
di
/
ndev
;
auto
dvoc
=
meta
.
dvoc
;
auto
dvoc
=
meta
.
dvoc
;
auto
stream
=
rsrc
.
stream
;
auto
stream
=
rsrc
.
stream
;
bool
has_qkv_bias
=
rsrc
.
b_attn_qkv
.
size
()
>
0
;
// Allocate buffers
// Allocate buffers
auto
logits_in
=
Tensor
::
buffer
(
dt_logits
,
{
ntok
,
d
},
stream
);
auto
logits_in
=
Tensor
::
buffer
(
dt_logits
,
{
ntok
,
d
},
stream
);
...
@@ -128,6 +128,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -128,6 +128,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
workspace_size
=
std
::
max
(
workspace_size
,
temp_size
);
workspace_size
=
std
::
max
(
workspace_size
,
temp_size
);
// Attention
// Attention
infiniopGemmDescriptor_t
desc_attn_qkv
,
desc_attn_o
;
infiniopGemmDescriptor_t
desc_attn_qkv
,
desc_attn_o
;
infiniopRearrangeDescriptor_t
desc_qkv_bias
;
if
(
has_qkv_bias
)
{
RUN_INFINI
(
infiniopCreateRearrangeDescriptor
(
rsrc
.
handle
,
&
desc_qkv_bias
,
qkv_buf
->
desc
()
->
get
(),
TensorDesc
::
create
(
dt_logits
,
{
ntok
,
(
nh
+
nkvh
*
2
)
*
dh
},
{
0
,
1
})
->
get
()));
}
RUN_INFINI
(
infiniopCreateGemmDescriptor
(
RUN_INFINI
(
infiniopCreateGemmDescriptor
(
rsrc
.
handle
,
&
desc_attn_qkv
,
qkv_buf
->
desc
()
->
get
(),
rsrc
.
handle
,
&
desc_attn_qkv
,
qkv_buf
->
desc
()
->
get
(),
logits_in
->
desc
()
->
get
(),
rsrc
.
w_attn_qkv
[
0
]
->
desc
()
->
get
()));
logits_in
->
desc
()
->
get
(),
rsrc
.
w_attn_qkv
[
0
]
->
desc
()
->
get
()));
...
@@ -224,7 +230,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -224,7 +230,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
workspace_size
=
std
::
max
(
workspace_size
,
temp_size
);
workspace_size
=
std
::
max
(
workspace_size
,
temp_size
);
// Allocate workspace
// Allocate workspace
workspace
=
rsrc
.
workspace_allocator
->
alloc
(
workspace_size
);
workspace
=
rsrc
.
workspace_allocator
->
alloc
(
workspace_size
);
// Compute
// Compute
for
(
uint32_t
layer
=
0
;
layer
<
nlayer
;
layer
++
)
{
for
(
uint32_t
layer
=
0
;
layer
<
nlayer
;
layer
++
)
{
// 1. Attention
// 1. Attention
...
@@ -234,6 +240,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -234,6 +240,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
logits_out
->
data
(),
logits_in
->
data
(),
logits_out
->
data
(),
logits_in
->
data
(),
rsrc
.
w_attn_norm
[
layer
]
->
data
(),
stream
));
rsrc
.
w_attn_norm
[
layer
]
->
data
(),
stream
));
// qkv_proj
// qkv_proj
if
(
has_qkv_bias
)
{
RUN_INFINI
(
infiniopRearrange
(
desc_qkv_bias
,
qkv_buf
->
data
(),
rsrc
.
b_attn_qkv
.
data
(),
stream
));
}
RUN_INFINI
(
infiniopGemm
(
RUN_INFINI
(
infiniopGemm
(
desc_attn_qkv
,
workspace
,
workspace_size
,
desc_attn_qkv
,
workspace
,
workspace_size
,
qkv_buf
->
data
(),
logits_out
->
data
(),
qkv_buf
->
data
(),
logits_out
->
data
(),
...
@@ -347,6 +358,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -347,6 +358,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// Clean up
// Clean up
infiniopDestroyRMSNormDescriptor
(
desc_norm
);
infiniopDestroyRMSNormDescriptor
(
desc_norm
);
infiniopDestroyRearrangeDescriptor
(
desc_qkv_bias
);
infiniopDestroyGemmDescriptor
(
desc_attn_qkv
);
infiniopDestroyGemmDescriptor
(
desc_attn_qkv
);
infiniopDestroyGemmDescriptor
(
desc_attn_o
);
infiniopDestroyGemmDescriptor
(
desc_attn_o
);
infiniopDestroyRoPEDescriptor
(
desc_rope_q
);
infiniopDestroyRoPEDescriptor
(
desc_rope_q
);
...
@@ -354,6 +366,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -354,6 +366,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
infiniopDestroyAttentionDescriptor
(
desc_attns
[
req
]);
infiniopDestroyAttentionDescriptor
(
desc_attns
[
req
]);
}
}
infiniopDestroyGemmDescriptor
(
desc_ffn_gate_up
);
infiniopDestroySwiGLUDescriptor
(
desc_swiglu
);
infiniopDestroyGemmDescriptor
(
desc_ffn_down
);
infiniopDestroyRMSNormDescriptor
(
desc_norm_out
);
infiniopDestroyRMSNormDescriptor
(
desc_norm_out
);
infiniopDestroyGemmDescriptor
(
desc_out_embd
);
infiniopDestroyGemmDescriptor
(
desc_out_embd
);
infiniopDestroyRandomSampleDescriptor
(
desc_sample
);
infiniopDestroyRandomSampleDescriptor
(
desc_sample
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment