gaoqiong / composable_kernel · Commits · c9915508

Commit c9915508, authored Aug 16, 2023 by letaoqin

    fix group d data type

Parent: 98df59c6
Changes: 2
Showing 2 changed files, with 10 additions and 11 deletions:

  example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp   (+1, -2)
  example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc  (+9, -9)
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp (view file @ c9915508)

@@ -48,12 +48,11 @@ using ADataType = DataType;
 using B0DataType       = DataType;
 using B1DataType       = DataType;
 using AccDataType      = F32;
-using DDataType        = F16;
 using CShuffleDataType = F32;
 using CDataType        = DataType;
 using ZDataType        = U16; // INT32
 using LSEDataType      = F32;
-using Acc0BiasDataType = DDataType;
+using Acc0BiasDataType = F16;
 using Acc1BiasDataType = void;
 static constexpr ck::index_t NumDimG = 2;
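
For orientation: after this hunk the bias element type is spelled directly as F16 and the intermediate DDataType alias is removed, so the host code in the .inc file below can refer to the single Acc0BiasDataType alias everywhere. A minimal, self-contained sketch of the alias layout, assuming stand-in definitions for F16/F32/U16 (the real example defines these elsewhere in the repository):

    // Sketch only: F16/F32/U16 here are stand-ins (assumption), not the example's real typedefs.
    #include <cstdint>
    #include <cstdio>

    struct F16 { std::uint16_t bits; }; // placeholder for a 16-bit half type
    using F32 = float;
    using U16 = std::uint16_t;

    // After the commit, the bias type no longer routes through a separate DDataType alias.
    using AccDataType      = F32;
    using Acc0BiasDataType = F16; // previously: using DDataType = F16; using Acc0BiasDataType = DDataType;
    using Acc1BiasDataType = void;

    int main()
    {
        // Host tensors and device buffer sizing must agree on this one element type:
        std::printf("sizeof(Acc0BiasDataType) = %zu bytes\n", sizeof(Acc0BiasDataType)); // 2
        std::printf("sizeof(AccDataType)      = %zu bytes\n", sizeof(AccDataType));      // 4
        return 0;
    }

The hunks in run_grouped_multihead_attention_bias_forward.inc below apply exactly this substitution: every Tensor<DDataType>, sizeof(DDataType), and GeneratorTensor_*<DDataType> becomes the Acc0BiasDataType equivalent, and the host-side reference tensor d_g_m_n switches from AccDataType to Acc0BiasDataType.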

example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward.inc (view file @ c9915508)

@@ -67,7 +67,7 @@ int run(int argc, char* argv[])
 std::vector<Tensor<B0DataType>> b0_tensors;
 std::vector<Tensor<B1DataType>> b1_tensors;
 std::vector<Tensor<CDataType>> c_tensors;
-std::vector<Tensor<DDataType>> d_tensors;
+std::vector<Tensor<Acc0BiasDataType>> d_tensors;
 std::vector<Tensor<ZDataType>> z_tensors;
 std::vector<Tensor<LSEDataType>> lse_tensors;

@@ -157,7 +157,7 @@ int run(int argc, char* argv[])
 Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
 Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
 Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
-Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
 Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
 Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);

@@ -165,7 +165,7 @@ int run(int argc, char* argv[])
 flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
 num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
              sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
-             sizeof(DDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
+             sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
             Batch;
 if(i < 4)
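
As a quick check that the changed line only affects the byte count (the flop line is untouched), here is the same arithmetic evaluated standalone with hypothetical sizes; the element types and the values of M, N, K, O, Batch are placeholders, not taken from the example:

    // Sketch: evaluates the flop / num_byte formulas above with made-up sizes.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <type_traits>

    using ADataType        = float;         // stand-ins for the example's real element types
    using B0DataType       = float;
    using B1DataType       = float;
    using CDataType        = float;
    using Acc0BiasDataType = std::uint16_t; // 2-byte bias element, as after this commit

    int main()
    {
        const std::size_t M = 256, N = 256, K = 64, O = 64, Batch = 4;

        // two GEMMs (MxK * KxN and MxN * NxO), 2 flops per multiply-accumulate
        std::size_t flop = (std::size_t(M) * N * K * 2 + std::size_t(M) * N * O * 2) * Batch;

        // bytes moved; the bias term counts only when Acc0BiasDataType is not void
        std::size_t num_byte =
            (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
             sizeof(Acc0BiasDataType) * M * N *
                 (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
            Batch;

        std::printf("flop = %zu, num_byte = %zu\n", flop, num_byte);
        return 0;
    }

Since Acc0BiasDataType was already F16 (via the old DDataType alias) before the change, the computed num_byte is unchanged; the edit only removes the reference to the deleted alias.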

@@ -189,25 +189,25 @@ int run(int argc, char* argv[])
     a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
     b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
     b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
-    d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1});
+    d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
     break;
 case 2:
     a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
     b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
     b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
-    d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
+    d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
     break;
 case 3:
     a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
     b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
     b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
-    d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
+    d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
     break;
 default:
     a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
     b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
     b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
-    d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
+    d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
 }
 a_tensors.push_back(a_gs_ms_ks);
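
The switch above selects an initialization pattern for the input and bias tensors (random integer range, random float range, diagonal, sequential, or constant); the only change in this hunk is the template argument on the d_gs_ms_ns calls, where DDataType becomes Acc0BiasDataType to match the alias change in the .cpp file. A rough standalone illustration of constant and integer-range fills follows; the helper names and exact semantics are assumptions for illustration, not the library's real GeneratorTensor_* API:

    // Illustrative only: loosely mirrors what constant / integer-range tensor generators do.
    #include <cstdio>
    #include <random>
    #include <vector>

    template <typename T>
    void fill_constant(std::vector<T>& buf, T value) // cf. GeneratorTensor_1<T>{value}
    {
        for(auto& x : buf)
            x = value;
    }

    template <typename T>
    void fill_int_range(std::vector<T>& buf, int lo, int hi) // cf. GeneratorTensor_2<T>{lo, hi}
    {
        std::mt19937 rng(0);
        std::uniform_int_distribution<int> dist(lo, hi);
        for(auto& x : buf)
            x = static_cast<T>(dist(rng));
    }

    int main()
    {
        std::vector<float> d_bias(8);
        fill_int_range(d_bias, -1, 1); // analogous to GeneratorTensor_2<Acc0BiasDataType>{-1, 1}
        fill_constant(d_bias, 1.0f);   // analogous to GeneratorTensor_1<Acc0BiasDataType>{1}
        std::printf("d_bias[0] = %g\n", d_bias[0]);
        return 0;
    }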
@@ -227,7 +227,7 @@ int run(int argc, char* argv[])
...
@@ -227,7 +227,7 @@ int run(int argc, char* argv[])
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
c_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
CDataType
)
*
c_gs_ms_os_device_result
.
mDesc
.
GetElementSpaceSize
()));
sizeof
(
CDataType
)
*
c_gs_ms_os_device_result
.
mDesc
.
GetElementSpaceSize
()));
d_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
d_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
D
DataType
)
*
d_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
()));
sizeof
(
Acc0Bias
DataType
)
*
d_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
()));
z_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
z_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
()));
sizeof
(
ZDataType
)
*
z_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
()));
lse_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
lse_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
...
@@ -359,7 +359,7 @@ int run(int argc, char* argv[])
...
@@ -359,7 +359,7 @@ int run(int argc, char* argv[])
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
,
K
,
N
});
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
,
N
,
O
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after gemm0
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after gemm0
Tensor
<
AccDataType
>
d_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
Acc
0Bias
DataType
>
d_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
,
M
,
O
});
// scratch object after gemm1
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
,
M
,
O
});
// scratch object after gemm1
...
...