Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a2bb4757
Commit
a2bb4757
authored
Jan 13, 2023
by
fsx950223
Browse files
fix example
parent
5509e684
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
18 deletions
+18
-18
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
...oftmax_gemm/grouped_multihead_attention_backward_fp16.cpp
+18
-18
No files found.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
View file @
a2bb4757
...
@@ -287,15 +287,15 @@ int run(int argc, char* argv[])
...
@@ -287,15 +287,15 @@ int run(int argc, char* argv[])
std
::
vector
<
DeviceGemmInstance
::
ProblemDesc
>
problem_descs
;
std
::
vector
<
DeviceGemmInstance
::
ProblemDesc
>
problem_descs
;
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
using
DeviceMemPtr
=
std
::
unique_ptr
<
DeviceMem
>
;
std
::
vector
<
const
DataType
*>
p_q
;
std
::
vector
<
const
void
*>
p_q
;
std
::
vector
<
const
DataType
*>
p_k
;
std
::
vector
<
const
void
*>
p_k
;
std
::
vector
<
const
DataType
*>
p_v
;
std
::
vector
<
const
void
*>
p_v
;
std
::
vector
<
const
DataType
*>
p_y
;
std
::
vector
<
const
void
*>
p_y
;
std
::
vector
<
const
LSEDataType
*>
p_lse
;
std
::
vector
<
const
void
*>
p_lse
;
std
::
vector
<
DataType
*>
p_qgrad
;
std
::
vector
<
void
*>
p_qgrad
;
std
::
vector
<
DataType
*>
p_kgrad
;
std
::
vector
<
void
*>
p_kgrad
;
std
::
vector
<
DataType
*>
p_vgrad
;
std
::
vector
<
void
*>
p_vgrad
;
std
::
vector
<
const
DataType
*>
p_ygrad
;
std
::
vector
<
const
void
*>
p_ygrad
;
std
::
vector
<
Tensor
<
DataType
>>
q_g_m_ks
;
std
::
vector
<
Tensor
<
DataType
>>
q_g_m_ks
;
std
::
vector
<
Tensor
<
DataType
>>
k_g_n_ks
;
std
::
vector
<
Tensor
<
DataType
>>
k_g_n_ks
;
...
@@ -517,15 +517,15 @@ int run(int argc, char* argv[])
...
@@ -517,15 +517,15 @@ int run(int argc, char* argv[])
kgrad_tensors_device
.
back
()
->
SetZero
();
kgrad_tensors_device
.
back
()
->
SetZero
();
vgrad_tensors_device
.
back
()
->
SetZero
();
vgrad_tensors_device
.
back
()
->
SetZero
();
ygrad_tensors_device
.
back
()
->
ToDevice
(
ygrad_gs_ms_os
.
data
());
ygrad_tensors_device
.
back
()
->
ToDevice
(
ygrad_gs_ms_os
.
data
());
p_q
.
push_back
(
static_cast
<
DataType
*>
(
q_tensors_device
.
back
()
->
GetDeviceBuffer
())
)
;
p_q
.
push_back
(
q_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_k
.
push_back
(
static_cast
<
DataType
*>
(
k_tensors_device
.
back
()
->
GetDeviceBuffer
())
)
;
p_k
.
push_back
(
k_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_v
.
push_back
(
static_cast
<
DataType
*>
(
v_tensors_device
.
back
()
->
GetDeviceBuffer
())
)
;
p_v
.
push_back
(
v_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_y
.
push_back
(
static_cast
<
DataType
*>
(
y_tensors_device
.
back
()
->
GetDeviceBuffer
())
)
;
p_y
.
push_back
(
y_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_lse
.
push_back
(
static_cast
<
LSEDataType
*>
(
lse_tensors_device
.
back
()
->
GetDeviceBuffer
())
)
;
p_lse
.
push_back
(
lse_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_kgrad
.
push_back
(
static_cast
<
DataType
*>
(
kgrad_tensors_device
.
back
()
->
GetDeviceBuffer
())
)
;
p_kgrad
.
push_back
(
kgrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_vgrad
.
push_back
(
static_cast
<
DataType
*>
(
vgrad_tensors_device
.
back
()
->
GetDeviceBuffer
())
)
;
p_vgrad
.
push_back
(
vgrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_ygrad
.
push_back
(
static_cast
<
DataType
*>
(
ygrad_tensors_device
.
back
()
->
GetDeviceBuffer
())
)
;
p_ygrad
.
push_back
(
ygrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_qgrad
.
push_back
(
static_cast
<
DataType
*>
(
qgrad_tensors_device
.
back
()
->
GetDeviceBuffer
())
)
;
p_qgrad
.
push_back
(
qgrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
}
}
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
p_q
,
p_q
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment