Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5ae304df
Commit
5ae304df
authored
May 17, 2022
by
myamlak
Browse files
Second auxiliary buffer added
parent
b3767dbe
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
46 additions
and
21 deletions
+46
-21
example/20_cgemm/cgemm_xdl_bf16.cpp
example/20_cgemm/cgemm_xdl_bf16.cpp
+4
-0
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
+1
-0
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
..._operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
+19
-12
test/cgemm/cgemm_util.hpp
test/cgemm/cgemm_util.hpp
+22
-9
No files found.
example/20_cgemm/cgemm_xdl_bf16.cpp
View file @
5ae304df
...
@@ -151,6 +151,7 @@ int main(int argc, char* argv[])
...
@@ -151,6 +151,7 @@ int main(int argc, char* argv[])
Tensor
<
CDataType
>
c_m_n_real_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_real_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_imag_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_imag_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
aux
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
aux
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
aux_2
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k_real: "
<<
a_m_k_real
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k_real: "
<<
a_m_k_real
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k_imag: "
<<
a_m_k_imag
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k_imag: "
<<
a_m_k_imag
.
mDesc
<<
std
::
endl
;
...
@@ -159,6 +160,7 @@ int main(int argc, char* argv[])
...
@@ -159,6 +160,7 @@ int main(int argc, char* argv[])
std
::
cout
<<
"c_m_n_real: "
<<
c_m_n_real_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_real: "
<<
c_m_n_real_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_imag: "
<<
c_m_n_imag_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_imag: "
<<
c_m_n_imag_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"aux: "
<<
aux
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"aux: "
<<
aux
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"aux_2: "
<<
aux_2
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -185,6 +187,7 @@ int main(int argc, char* argv[])
...
@@ -185,6 +187,7 @@ int main(int argc, char* argv[])
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_imag_device_result
.
mDesc
.
GetElementSpace
());
c_m_n_imag_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
aux_device_buf
(
sizeof
(
CDataType
)
*
aux
.
mDesc
.
GetElementSpace
());
DeviceMem
aux_device_buf
(
sizeof
(
CDataType
)
*
aux
.
mDesc
.
GetElementSpace
());
DeviceMem
aux_2_device_buf
(
sizeof
(
CDataType
)
*
aux_2
.
mDesc
.
GetElementSpace
());
a_m_k_real_device_buf
.
ToDevice
(
a_m_k_real
.
mData
.
data
());
a_m_k_real_device_buf
.
ToDevice
(
a_m_k_real
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
a_m_k_imag
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
a_m_k_imag
.
mData
.
data
());
...
@@ -206,6 +209,7 @@ int main(int argc, char* argv[])
...
@@ -206,6 +209,7 @@ int main(int argc, char* argv[])
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_2_device_buf
.
GetDeviceBuffer
()),
M
,
M
,
N
,
N
,
K
,
K
,
...
...
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
View file @
5ae304df
...
@@ -20,6 +20,7 @@ struct DeviceCGemm : public BaseOperator
...
@@ -20,6 +20,7 @@ struct DeviceCGemm : public BaseOperator
void
*
p_c_real
,
void
*
p_c_real
,
void
*
p_c_imag
,
void
*
p_c_imag
,
void
*
p_aux
,
void
*
p_aux
,
void
*
p_aux_2
,
ck
::
index_t
M
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
...
...
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
5ae304df
...
@@ -390,6 +390,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -390,6 +390,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType
*
p_c_grid_real
,
CDataType
*
p_c_grid_real
,
CDataType
*
p_c_grid_imag
,
CDataType
*
p_c_grid_imag
,
CDataType
*
p_aux_grid
,
CDataType
*
p_aux_grid
,
CDataType
*
p_aux_2_grid
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
KRaw
,
...
@@ -406,6 +407,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -406,6 +407,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_grid_real_
{
p_c_grid_real
},
p_c_grid_real_
{
p_c_grid_real
},
p_c_grid_imag_
{
p_c_grid_imag
},
p_c_grid_imag_
{
p_c_grid_imag
},
p_aux_grid_
{
p_aux_grid
},
p_aux_grid_
{
p_aux_grid
},
p_aux_2_grid_
{
p_aux_2_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
...
@@ -434,6 +436,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -434,6 +436,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType
*
p_c_grid_real_
;
CDataType
*
p_c_grid_real_
;
CDataType
*
p_c_grid_imag_
;
CDataType
*
p_c_grid_imag_
;
CDataType
*
p_aux_grid_
;
CDataType
*
p_aux_grid_
;
CDataType
*
p_aux_2_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
...
@@ -488,7 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -488,7 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0
,
0
,
arg
.
p_a_grid_real_
,
arg
.
p_a_grid_real_
,
arg
.
p_b_grid_real_
,
arg
.
p_b_grid_real_
,
arg
.
p_
c
_grid_
real_
,
arg
.
p_
aux
_grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
...
@@ -505,7 +508,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -505,7 +508,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0
,
0
,
arg
.
p_a_grid_imag_
,
arg
.
p_a_grid_imag_
,
arg
.
p_b_grid_imag_
,
arg
.
p_b_grid_imag_
,
arg
.
p_aux_grid_
,
arg
.
p_aux_
2_
grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
...
@@ -514,7 +517,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -514,7 +517,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
// c_real =
c_real
- aux needed here!!!
// c_real =
aux
- aux
_2
needed here!!!
ave_time
+=
ave_time
+=
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
...
@@ -524,7 +527,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -524,7 +527,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0
,
0
,
arg
.
p_a_grid_real_
,
arg
.
p_a_grid_real_
,
arg
.
p_b_grid_imag_
,
arg
.
p_b_grid_imag_
,
arg
.
p_
c
_grid_
imag_
,
arg
.
p_
aux
_grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
...
@@ -541,7 +544,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -541,7 +544,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0
,
0
,
arg
.
p_a_grid_imag_
,
arg
.
p_a_grid_imag_
,
arg
.
p_b_grid_real_
,
arg
.
p_b_grid_real_
,
arg
.
p_aux_grid_
,
arg
.
p_aux_
2_
grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
...
@@ -550,7 +553,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -550,7 +553,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
// c_imag =
c_imag
+ aux needed here!!!
// c_imag =
aux
+ aux
_2
needed here!!!
}
}
else
else
{
{
...
@@ -575,7 +578,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -575,7 +578,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0
,
0
,
arg
.
p_a_grid_real_
,
arg
.
p_a_grid_real_
,
arg
.
p_b_grid_real_
,
arg
.
p_b_grid_real_
,
arg
.
p_
c
_grid_
real_
,
arg
.
p_
aux
_grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
...
@@ -592,7 +595,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -592,7 +595,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0
,
0
,
arg
.
p_a_grid_imag_
,
arg
.
p_a_grid_imag_
,
arg
.
p_b_grid_imag_
,
arg
.
p_b_grid_imag_
,
arg
.
p_aux_grid_
,
arg
.
p_aux_
2_
grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
...
@@ -601,7 +604,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -601,7 +604,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
// // c_real =
c_real
- aux needed here!!!
// // c_real =
aux
- aux
_2
needed here!!!
ave_time
+=
ave_time
+=
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
...
@@ -611,7 +614,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -611,7 +614,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0
,
0
,
arg
.
p_a_grid_real_
,
arg
.
p_a_grid_real_
,
arg
.
p_b_grid_imag_
,
arg
.
p_b_grid_imag_
,
arg
.
p_
c
_grid_
imag_
,
arg
.
p_
aux
_grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
...
@@ -628,7 +631,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -628,7 +631,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0
,
0
,
arg
.
p_a_grid_imag_
,
arg
.
p_a_grid_imag_
,
arg
.
p_b_grid_real_
,
arg
.
p_b_grid_real_
,
arg
.
p_aux_grid_
,
arg
.
p_aux_
2_
grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
...
@@ -637,7 +640,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -637,7 +640,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
// c_imag =
c_imag
+ aux needed here!!!
// c_imag =
aux
+ aux
_2
needed here!!!
}
}
return
ave_time
;
return
ave_time
;
...
@@ -676,6 +679,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -676,6 +679,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType
*
p_c_real
,
CDataType
*
p_c_real
,
CDataType
*
p_c_imag
,
CDataType
*
p_c_imag
,
CDataType
*
p_aux
,
CDataType
*
p_aux
,
CDataType
*
p_aux_2
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
KRaw
,
...
@@ -693,6 +697,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -693,6 +697,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_real
,
p_c_real
,
p_c_imag
,
p_c_imag
,
p_aux
,
p_aux
,
p_aux_2
,
MRaw
,
MRaw
,
NRaw
,
NRaw
,
KRaw
,
KRaw
,
...
@@ -714,6 +719,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -714,6 +719,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
void
*
p_c_real
,
void
*
p_c_real
,
void
*
p_c_imag
,
void
*
p_c_imag
,
void
*
p_aux
,
void
*
p_aux
,
void
*
p_aux_2
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
KRaw
,
...
@@ -732,6 +738,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -732,6 +738,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static_cast
<
CDataType
*>
(
p_c_real
),
static_cast
<
CDataType
*>
(
p_c_real
),
static_cast
<
CDataType
*>
(
p_c_imag
),
static_cast
<
CDataType
*>
(
p_c_imag
),
static_cast
<
CDataType
*>
(
p_aux
),
static_cast
<
CDataType
*>
(
p_aux
),
static_cast
<
CDataType
*>
(
p_aux_2
),
MRaw
,
MRaw
,
NRaw
,
NRaw
,
KRaw
,
KRaw
,
...
...
test/cgemm/cgemm_util.hpp
View file @
5ae304df
...
@@ -73,6 +73,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
...
@@ -73,6 +73,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
Tensor
<
CDataType
>&
C_real
,
Tensor
<
CDataType
>&
C_real
,
Tensor
<
CDataType
>&
C_imag
,
Tensor
<
CDataType
>&
C_imag
,
Tensor
<
CDataType
>&
Aux
,
Tensor
<
CDataType
>&
Aux
,
Tensor
<
CDataType
>&
Aux_2
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
...
@@ -84,6 +85,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
...
@@ -84,6 +85,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
DeviceMem
c_m_n_real_device_buf
(
sizeof
(
CDataType
)
*
C_real
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_real_device_buf
(
sizeof
(
CDataType
)
*
C_real
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
C_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
C_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
aux_device_buf
(
sizeof
(
CDataType
)
*
Aux
.
mDesc
.
GetElementSpace
());
DeviceMem
aux_device_buf
(
sizeof
(
CDataType
)
*
Aux
.
mDesc
.
GetElementSpace
());
DeviceMem
aux_2_device_buf
(
sizeof
(
CDataType
)
*
Aux_2
.
mDesc
.
GetElementSpace
());
a_m_k_real_device_buf
.
ToDevice
(
A_real
.
mData
.
data
());
a_m_k_real_device_buf
.
ToDevice
(
A_real
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
A_imag
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
A_imag
.
mData
.
data
());
...
@@ -99,6 +101,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
...
@@ -99,6 +101,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_2_device_buf
.
GetDeviceBuffer
()),
params
.
M
,
params
.
M
,
params
.
N
,
params
.
N
,
params
.
K
,
params
.
K
,
...
@@ -167,6 +170,8 @@ struct TestCGemm
...
@@ -167,6 +170,8 @@ struct TestCGemm
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
aux
(
Tensor
<
CDataType
>
aux
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
aux_2
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
auto
f_generate_tensor_value
=
[](
auto
&
tensor
,
auto
type
)
{
auto
f_generate_tensor_value
=
[](
auto
&
tensor
,
auto
type
)
{
using
dataType
=
decltype
(
type
);
using
dataType
=
decltype
(
type
);
...
@@ -187,7 +192,8 @@ struct TestCGemm
...
@@ -187,7 +192,8 @@ struct TestCGemm
c_m_n_imag_host_result
,
c_m_n_imag_host_result
,
c_m_n_real_device_result
,
c_m_n_real_device_result
,
c_m_n_imag_device_result
,
c_m_n_imag_device_result
,
aux
);
aux
,
aux_2
);
}
}
auto
operator
()(
DeviceCGemmPtr_
&
cgemmPtr
)
auto
operator
()(
DeviceCGemmPtr_
&
cgemmPtr
)
...
@@ -216,6 +222,7 @@ struct TestCGemm
...
@@ -216,6 +222,7 @@ struct TestCGemm
Tensor
<
CDataType
>&
c_device_real
=
std
::
get
<
6
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_device_real
=
std
::
get
<
6
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_device_imag
=
std
::
get
<
7
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_device_imag
=
std
::
get
<
7
>
(
host_tensors
);
Tensor
<
CDataType
>&
aux
=
std
::
get
<
8
>
(
host_tensors
);
Tensor
<
CDataType
>&
aux
=
std
::
get
<
8
>
(
host_tensors
);
Tensor
<
CDataType
>&
aux_2
=
std
::
get
<
9
>
(
host_tensors
);
auto
a_element_op
=
AElementwiseOperation
{};
auto
a_element_op
=
AElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
...
@@ -248,6 +255,7 @@ struct TestCGemm
...
@@ -248,6 +255,7 @@ struct TestCGemm
c_device_real
,
c_device_real
,
c_device_imag
,
c_device_imag
,
aux
,
aux
,
aux_2
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
...
@@ -319,6 +327,8 @@ struct TestCGemmBF16
...
@@ -319,6 +327,8 @@ struct TestCGemmBF16
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
BF16
>
aux_bf16
(
Tensor
<
BF16
>
aux_bf16
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
BF16
>
aux_2_bf16
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
float
>
a_m_k_real_fp32
(
Tensor
<
float
>
a_m_k_real_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
...
@@ -354,6 +364,7 @@ struct TestCGemmBF16
...
@@ -354,6 +364,7 @@ struct TestCGemmBF16
c_m_n_real_device_bf16
,
c_m_n_real_device_bf16
,
c_m_n_imag_device_bf16
,
c_m_n_imag_device_bf16
,
aux_bf16
,
aux_bf16
,
aux_2_bf16
,
a_m_k_real_fp32
,
a_m_k_real_fp32
,
a_m_k_imag_fp32
,
a_m_k_imag_fp32
,
b_k_n_real_fp32
,
b_k_n_real_fp32
,
...
@@ -383,14 +394,15 @@ struct TestCGemmBF16
...
@@ -383,14 +394,15 @@ struct TestCGemmBF16
Tensor
<
BF16
>&
c_real_device_bf16
=
std
::
get
<
4
>
(
host_tensors
);
Tensor
<
BF16
>&
c_real_device_bf16
=
std
::
get
<
4
>
(
host_tensors
);
Tensor
<
BF16
>&
c_imag_device_bf16
=
std
::
get
<
5
>
(
host_tensors
);
Tensor
<
BF16
>&
c_imag_device_bf16
=
std
::
get
<
5
>
(
host_tensors
);
Tensor
<
BF16
>&
aux_bf16
=
std
::
get
<
6
>
(
host_tensors
);
Tensor
<
BF16
>&
aux_bf16
=
std
::
get
<
6
>
(
host_tensors
);
Tensor
<
float
>&
a_real_fp32
=
std
::
get
<
7
>
(
host_tensors
);
Tensor
<
BF16
>&
aux_2_bf16
=
std
::
get
<
7
>
(
host_tensors
);
Tensor
<
float
>&
a_imag_fp32
=
std
::
get
<
8
>
(
host_tensors
);
Tensor
<
float
>&
a_real_fp32
=
std
::
get
<
8
>
(
host_tensors
);
Tensor
<
float
>&
b_real_fp32
=
std
::
get
<
9
>
(
host_tensors
);
Tensor
<
float
>&
a_imag_fp32
=
std
::
get
<
9
>
(
host_tensors
);
Tensor
<
float
>&
b_imag_fp32
=
std
::
get
<
10
>
(
host_tensors
);
Tensor
<
float
>&
b_real_fp32
=
std
::
get
<
10
>
(
host_tensors
);
Tensor
<
float
>&
c_real_host_fp32
=
std
::
get
<
11
>
(
host_tensors
);
Tensor
<
float
>&
b_imag_fp32
=
std
::
get
<
11
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_host_fp32
=
std
::
get
<
12
>
(
host_tensors
);
Tensor
<
float
>&
c_real_host_fp32
=
std
::
get
<
12
>
(
host_tensors
);
Tensor
<
float
>&
c_real_device_fp32
=
std
::
get
<
13
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_host_fp32
=
std
::
get
<
13
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_device_fp32
=
std
::
get
<
14
>
(
host_tensors
);
Tensor
<
float
>&
c_real_device_fp32
=
std
::
get
<
14
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_device_fp32
=
std
::
get
<
15
>
(
host_tensors
);
auto
a_element_op
=
AElementwiseOperation
{};
auto
a_element_op
=
AElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
...
@@ -424,6 +436,7 @@ struct TestCGemmBF16
...
@@ -424,6 +436,7 @@ struct TestCGemmBF16
c_real_device_bf16
,
c_real_device_bf16
,
c_imag_device_bf16
,
c_imag_device_bf16
,
aux_bf16
,
aux_bf16
,
aux_2_bf16
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment