Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
674f74ad
Commit
674f74ad
authored
May 16, 2022
by
myamlak
Browse files
Test fixes.
parent
14bd1430
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
89 deletions
+55
-89
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp
...ibrary/reference_tensor_operation/cpu/reference_cgemm.hpp
+4
-4
test/cgemm/cgemm_bf16.cpp
test/cgemm/cgemm_bf16.cpp
+6
-6
test/cgemm/cgemm_fp16.cpp
test/cgemm/cgemm_fp16.cpp
+2
-19
test/cgemm/cgemm_fp32.cpp
test/cgemm/cgemm_fp32.cpp
+2
-17
test/cgemm/cgemm_util.hpp
test/cgemm/cgemm_util.hpp
+41
-43
No files found.
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp
View file @
674f74ad
...
@@ -60,7 +60,7 @@ struct ReferenceCGemm : public device::BaseOperator
...
@@ -60,7 +60,7 @@ struct ReferenceCGemm : public device::BaseOperator
float
Run
(
const
Argument
&
arg
)
float
Run
(
const
Argument
&
arg
)
{
{
auto
f_mk_kn_mn_real
=
[
&
](
auto
m
,
auto
n
)
{
auto
f_mk_kn_mn_real
=
[
&
](
auto
m
,
auto
n
)
{
const
in
t
K
=
arg
.
a_m_k_real_
.
mDesc
.
GetLengths
()[
1
];
const
std
::
size_
t
K
=
arg
.
a_m_k_real_
.
mDesc
.
GetLengths
()[
1
];
if
(
K
!=
arg
.
a_m_k_imag_
.
mDesc
.
GetLengths
()[
1
])
if
(
K
!=
arg
.
a_m_k_imag_
.
mDesc
.
GetLengths
()[
1
])
{
{
...
@@ -69,7 +69,7 @@ struct ReferenceCGemm : public device::BaseOperator
...
@@ -69,7 +69,7 @@ struct ReferenceCGemm : public device::BaseOperator
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
in
t
k
=
0
;
k
<
K
;
++
k
)
for
(
std
::
size_
t
k
=
0
;
k
<
K
;
++
k
)
{
{
float
v_a_real
;
float
v_a_real
;
float
v_b_real
;
float
v_b_real
;
...
@@ -92,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator
...
@@ -92,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator
};
};
auto
f_mk_kn_mn_imag
=
[
&
](
auto
m
,
auto
n
)
{
auto
f_mk_kn_mn_imag
=
[
&
](
auto
m
,
auto
n
)
{
const
in
t
K
=
arg
.
a_m_k_real_
.
mDesc
.
GetLengths
()[
1
];
const
std
::
size_
t
K
=
arg
.
a_m_k_real_
.
mDesc
.
GetLengths
()[
1
];
if
(
K
!=
arg
.
a_m_k_imag_
.
mDesc
.
GetLengths
()[
1
])
if
(
K
!=
arg
.
a_m_k_imag_
.
mDesc
.
GetLengths
()[
1
])
{
{
...
@@ -101,7 +101,7 @@ struct ReferenceCGemm : public device::BaseOperator
...
@@ -101,7 +101,7 @@ struct ReferenceCGemm : public device::BaseOperator
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
in
t
k
=
0
;
k
<
K
;
++
k
)
for
(
std
::
size_
t
k
=
0
;
k
<
K
;
++
k
)
{
{
float
v_a_real
;
float
v_a_real
;
float
v_b_real
;
float
v_b_real
;
...
...
test/cgemm/cgemm_bf16.cpp
View file @
674f74ad
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceCGemmNoOpPtr
=
using
DeviceCGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
device
::
Device
C
GemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
...
@@ -48,9 +48,9 @@ int main()
...
@@ -48,9 +48,9 @@ int main()
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
bool
res
=
true
;
std
::
vector
<
DeviceCGemmNoOpPtr
>
gemmPtrs
;
std
::
vector
<
DeviceCGemmNoOpPtr
>
c
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
ck
::
tensor_operation
::
device
::
device_
c
gemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
cgemmPtrs
);
add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
cgemmPtrs
);
for
(
auto
&
cgemmPtr
:
cgemmPtrs
)
for
(
auto
&
cgemmPtr
:
cgemmPtrs
)
...
@@ -76,7 +76,7 @@ int main()
...
@@ -76,7 +76,7 @@ int main()
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
PassThrough
>
{}(
c
gemmPtr
);
}
}
cgemmPtrs
.
clear
();
cgemmPtrs
.
clear
();
...
...
test/cgemm/cgemm_fp16.cpp
View file @
674f74ad
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceCGemmNoOpPtr
=
using
DeviceCGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
Device
cg
emmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
device
::
Device
CG
emmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
...
@@ -50,10 +50,7 @@ int main()
...
@@ -50,10 +50,7 @@ int main()
bool
res
=
true
;
bool
res
=
true
;
std
::
vector
<
DeviceCGemmNoOpPtr
>
cgemmPtrs
;
std
::
vector
<
DeviceCGemmNoOpPtr
>
cgemmPtrs
;
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
cgemmPtrs
);
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
cgemmPtrs
);
...
@@ -72,10 +69,6 @@ int main()
...
@@ -72,10 +69,6 @@ int main()
}
}
cgemmPtrs
.
clear
();
cgemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
cgemmPtrs
);
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
cgemmPtrs
);
...
@@ -94,10 +87,6 @@ int main()
...
@@ -94,10 +87,6 @@ int main()
}
}
cgemmPtrs
.
clear
();
cgemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
cgemmPtrs
);
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
cgemmPtrs
);
...
@@ -116,14 +105,8 @@ int main()
...
@@ -116,14 +105,8 @@ int main()
}
}
cgemmPtrs
.
clear
();
cgemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
cgemmPtrs
);
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
cgemmPtrs
);
for
(
auto
&
cgemmPtr
:
cgemmPtrs
)
for
(
auto
&
cgemmPtr
:
cgemmPtrs
)
{
{
...
...
test/cgemm/cgemm_fp32.cpp
View file @
674f74ad
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceCGemmNoOpPtr
=
using
DeviceCGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
Device
cg
emmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
device
::
Device
CG
emmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
...
@@ -54,10 +54,7 @@ int main()
...
@@ -54,10 +54,7 @@ int main()
bool
res
=
true
;
bool
res
=
true
;
std
::
vector
<
DeviceCGemmNoOpPtr
>
cgemmPtrs
;
std
::
vector
<
DeviceCGemmNoOpPtr
>
cgemmPtrs
;
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
cgemmPtrs
);
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
cgemmPtrs
);
...
@@ -76,10 +73,6 @@ int main()
...
@@ -76,10 +73,6 @@ int main()
}
}
cgemmPtrs
.
clear
();
cgemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
cgemmPtrs
);
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
cgemmPtrs
);
...
@@ -98,10 +91,6 @@ int main()
...
@@ -98,10 +91,6 @@ int main()
}
}
cgemmPtrs
.
clear
();
cgemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
cgemmPtrs
);
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
cgemmPtrs
);
...
@@ -120,10 +109,6 @@ int main()
...
@@ -120,10 +109,6 @@ int main()
}
}
cgemmPtrs
.
clear
();
cgemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
cgemmPtrs
);
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
ck
::
tensor_operation
::
device
::
device_cgemm_instance
::
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
cgemmPtrs
);
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
cgemmPtrs
);
...
...
test/cgemm/cgemm_util.hpp
View file @
674f74ad
...
@@ -77,21 +77,23 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
...
@@ -77,21 +77,23 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
{
{
DeviceMem
a_m_k_real_device_buf
(
sizeof
(
ADataType
)
*
A
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_real_device_buf
(
sizeof
(
ADataType
)
*
A_real
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_imag_device_buf
(
sizeof
(
ADataType
)
*
A
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_imag_device_buf
(
sizeof
(
ADataType
)
*
A_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_real_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_real_device_buf
(
sizeof
(
BDataType
)
*
B_real
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_imag_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_imag_device_buf
(
sizeof
(
BDataType
)
*
B_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_real_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_real_device_buf
(
sizeof
(
CDataType
)
*
C_real
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
C_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
aux_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpace
());
DeviceMem
aux_device_buf
(
sizeof
(
CDataType
)
*
Aux
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
A
.
mData
.
data
());
a_m_k_real_device_buf
.
ToDevice
(
A_real
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
B
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
A_imag
.
mData
.
data
());
b_k_n_real_device_buf
.
ToDevice
(
B_real
.
mData
.
data
());
b_k_n_imag_device_buf
.
ToDevice
(
B_imag
.
mData
.
data
());
auto
invoker_ptr
=
cgemmPtr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
cgemmPtr
->
MakeInvokerPointer
();
auto
argument_ptr
=
cgemmPtr
->
MakeArgumentPointer
(
auto
argument_ptr
=
cgemmPtr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_m_k_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_m_k_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_m_k_
real
_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_m_k_
imag
_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
...
@@ -255,7 +257,7 @@ struct TestCGemm
...
@@ -255,7 +257,7 @@ struct TestCGemm
if
(
std
::
is_same
<
CDataType
,
float
>::
value
)
if
(
std
::
is_same
<
CDataType
,
float
>::
value
)
{
{
res
=
ck
::
utils
::
check_err
(
c_device_real
.
mData
,
c_host_real
.
mData
)
&&
res
=
ck
::
utils
::
check_err
(
c_device_real
.
mData
,
c_host_real
.
mData
)
&&
ck
::
utils
::
check_err
(
c_device_
real
.
mData
,
c_host
.
mData
);
ck
::
utils
::
check_err
(
c_device_
imag
.
mData
,
c_host
_imag
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
}
else
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
else
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
...
@@ -326,15 +328,13 @@ struct TestCGemmBF16
...
@@ -326,15 +328,13 @@ struct TestCGemmBF16
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
Tensor
<
float
>
b_k_n_imag_fp32
(
Tensor
<
float
>
b_k_n_imag_fp32
(
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
Tensor
<
float
>
c_m_n_
host_real
_fp32
(
Tensor
<
float
>
c_m_n_
real_host
_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
float
>
c_m_n_
host_imag
_fp32
(
Tensor
<
float
>
c_m_n_
imag_host
_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
float
>
c_m_n_device_
real_
fp32
(
Tensor
<
float
>
c_m_n_
real_
device_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
float
>
c_m_n_device_imag_fp32
(
Tensor
<
float
>
c_m_n_imag_device_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
float
>
aux_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
a_m_k_real_bf16
.
GenerateTensorValue
(
GeneratorTensor_3
<
BF16
>
{
-
0.5
,
0.5
});
a_m_k_real_bf16
.
GenerateTensorValue
(
GeneratorTensor_3
<
BF16
>
{
-
0.5
,
0.5
});
...
@@ -361,8 +361,7 @@ struct TestCGemmBF16
...
@@ -361,8 +361,7 @@ struct TestCGemmBF16
c_m_n_real_host_fp32
,
c_m_n_real_host_fp32
,
c_m_n_imag_host_fp32
,
c_m_n_imag_host_fp32
,
c_m_n_real_device_fp32
,
c_m_n_real_device_fp32
,
c_m_n_imag_device_fp32
,
c_m_n_imag_device_fp32
);
aux_fp32
);
}
}
auto
operator
()(
DeviceCGemmPtr_
&
cgemmPtr
)
auto
operator
()(
DeviceCGemmPtr_
&
cgemmPtr
)
...
@@ -392,32 +391,31 @@ struct TestCGemmBF16
...
@@ -392,32 +391,31 @@ struct TestCGemmBF16
Tensor
<
float
>&
c_imag_host_fp32
=
std
::
get
<
12
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_host_fp32
=
std
::
get
<
12
>
(
host_tensors
);
Tensor
<
float
>&
c_real_device_fp32
=
std
::
get
<
13
>
(
host_tensors
);
Tensor
<
float
>&
c_real_device_fp32
=
std
::
get
<
13
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_device_fp32
=
std
::
get
<
14
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_device_fp32
=
std
::
get
<
14
>
(
host_tensors
);
Tensor
<
float
>&
aux_fp32
=
std
::
get
<
15
>
(
host_tensors
);
auto
a_element_op
=
AElementwiseOperation
{};
auto
a_element_op
=
AElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
c_element_op
=
CElementwiseOperation
{};
auto
c_element_op
=
CElementwiseOperation
{};
// use fp32 host kernel to verify bf16 device kernel
// use fp32 host kernel to verify bf16 device kernel
using
ReferenceGemmInstance
=
using
Reference
C
GemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceCGemm
<
float
,
ck
::
tensor_operation
::
host
::
ReferenceCGemm
<
float
,
float
,
float
,
float
,
float
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
CElementwiseOperation
>
;
ck
::
gemm_util
::
RunHostCGEMM
<
ReferenceCGemmInstance
>
(
a_real_fp32
,
ck
::
c
gemm_util
::
RunHostCGEMM
<
ReferenceCGemmInstance
>
(
a_real_fp32
,
a_imag_fp32
,
a_imag_fp32
,
b_real_fp32
,
b_real_fp32
,
b_imag_fp32
,
b_imag_fp32
,
c_real_host_fp32
,
c_real_host_fp32
,
c_imag_fp32
,
c_imag_
host_
fp32
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
// Act
// Act
ck
::
gemm_util
::
RunDeviceCGEMM
(
cgemmPtr
,
ck
::
c
gemm_util
::
RunDeviceCGEMM
(
cgemmPtr
,
params
,
params
,
a_real_bf16
,
a_real_bf16
,
a_imag_bf16
,
a_imag_bf16
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment