Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f73c3ea2
Commit
f73c3ea2
authored
May 23, 2022
by
myamlak
Browse files
Single workspace for cgemm + helper
parent
4379d8d1
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
58 deletions
+43
-58
example/21_cgemm/cgemm_xdl_bf16.cpp
example/21_cgemm/cgemm_xdl_bf16.cpp
+4
-9
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
+7
-2
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
..._operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
+20
-12
test/cgemm/cgemm_util.hpp
test/cgemm/cgemm_util.hpp
+12
-35
No files found.
example/21_cgemm/cgemm_xdl_bf16.cpp
View file @
f73c3ea2
...
@@ -150,8 +150,6 @@ int main(int argc, char* argv[])
...
@@ -150,8 +150,6 @@ int main(int argc, char* argv[])
Tensor
<
BDataType
>
b_k_n_imag
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n_imag
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_real_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_real_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_imag_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_imag_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
aux
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
aux_2
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k_real: "
<<
a_m_k_real
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k_real: "
<<
a_m_k_real
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k_imag: "
<<
a_m_k_imag
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k_imag: "
<<
a_m_k_imag
.
mDesc
<<
std
::
endl
;
...
@@ -159,8 +157,6 @@ int main(int argc, char* argv[])
...
@@ -159,8 +157,6 @@ int main(int argc, char* argv[])
std
::
cout
<<
"b_k_n_imag: "
<<
b_k_n_imag
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n_imag: "
<<
b_k_n_imag
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_real: "
<<
c_m_n_real_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_real: "
<<
c_m_n_real_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_imag: "
<<
c_m_n_imag_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_imag: "
<<
c_m_n_imag_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"aux: "
<<
aux
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"aux_2: "
<<
aux_2
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -178,6 +174,8 @@ int main(int argc, char* argv[])
...
@@ -178,6 +174,8 @@ int main(int argc, char* argv[])
b_k_n_imag
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_k_n_imag
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
}
}
auto
cgemm
=
DeviceCGemmInstance
{};
DeviceMem
a_m_k_real_device_buf
(
sizeof
(
ADataType
)
*
a_m_k_real
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_real_device_buf
(
sizeof
(
ADataType
)
*
a_m_k_real
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_imag_device_buf
(
sizeof
(
ADataType
)
*
a_m_k_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_imag_device_buf
(
sizeof
(
ADataType
)
*
a_m_k_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_real_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_real
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_real_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_real
.
mDesc
.
GetElementSpace
());
...
@@ -186,8 +184,7 @@ int main(int argc, char* argv[])
...
@@ -186,8 +184,7 @@ int main(int argc, char* argv[])
c_m_n_real_device_result
.
mDesc
.
GetElementSpace
());
c_m_n_real_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_imag_device_result
.
mDesc
.
GetElementSpace
());
c_m_n_imag_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
aux_device_buf
(
sizeof
(
CDataType
)
*
aux
.
mDesc
.
GetElementSpace
());
DeviceMem
workspace_device_buf
(
cgemm
.
GetWorkspaceSize
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
));
DeviceMem
aux_2_device_buf
(
sizeof
(
CDataType
)
*
aux_2
.
mDesc
.
GetElementSpace
());
a_m_k_real_device_buf
.
ToDevice
(
a_m_k_real
.
mData
.
data
());
a_m_k_real_device_buf
.
ToDevice
(
a_m_k_real
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
a_m_k_imag
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
a_m_k_imag
.
mData
.
data
());
...
@@ -199,7 +196,6 @@ int main(int argc, char* argv[])
...
@@ -199,7 +196,6 @@ int main(int argc, char* argv[])
auto
c_element_op
=
PassThrough
{};
auto
c_element_op
=
PassThrough
{};
// do GEMM
// do GEMM
auto
cgemm
=
DeviceCGemmInstance
{};
auto
invoker
=
cgemm
.
MakeInvoker
();
auto
invoker
=
cgemm
.
MakeInvoker
();
auto
argument
=
auto
argument
=
cgemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_real_device_buf
.
GetDeviceBuffer
()),
cgemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_real_device_buf
.
GetDeviceBuffer
()),
...
@@ -208,8 +204,7 @@ int main(int argc, char* argv[])
...
@@ -208,8 +204,7 @@ int main(int argc, char* argv[])
static_cast
<
BDataType
*>
(
b_k_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
workspace_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_2_device_buf
.
GetDeviceBuffer
()),
M
,
M
,
N
,
N
,
K
,
K
,
...
...
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
View file @
f73c3ea2
...
@@ -19,8 +19,7 @@ struct DeviceCGemm : public BaseOperator
...
@@ -19,8 +19,7 @@ struct DeviceCGemm : public BaseOperator
const
void
*
p_b_imag
,
const
void
*
p_b_imag
,
void
*
p_c_real
,
void
*
p_c_real
,
void
*
p_c_imag
,
void
*
p_c_imag
,
void
*
p_aux
,
void
*
p_workspace
,
void
*
p_aux_2
,
ck
::
index_t
M
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
...
@@ -33,6 +32,12 @@ struct DeviceCGemm : public BaseOperator
...
@@ -33,6 +32,12 @@ struct DeviceCGemm : public BaseOperator
ck
::
index_t
KBatch
=
1
)
=
0
;
ck
::
index_t
KBatch
=
1
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
size_t
GetWorkspaceSize
(
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
)
=
0
;
};
};
template
<
typename
AElementwiseOperation
,
template
<
typename
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
f73c3ea2
...
@@ -427,8 +427,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -427,8 +427,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
const
BDataType
*
p_b_grid_imag
,
const
BDataType
*
p_b_grid_imag
,
CDataType
*
p_c_grid_real
,
CDataType
*
p_c_grid_real
,
CDataType
*
p_c_grid_imag
,
CDataType
*
p_c_grid_imag
,
CDataType
*
p_aux_grid
,
CDataType
*
p_workspace
,
CDataType
*
p_aux_2_grid
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
KRaw
,
...
@@ -444,8 +443,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -444,8 +443,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_b_grid_imag_
{
p_b_grid_imag
},
p_b_grid_imag_
{
p_b_grid_imag
},
p_c_grid_real_
{
p_c_grid_real
},
p_c_grid_real_
{
p_c_grid_real
},
p_c_grid_imag_
{
p_c_grid_imag
},
p_c_grid_imag_
{
p_c_grid_imag
},
p_aux_grid_
{
p_aux_grid
},
p_aux_grid_
{
p_workspace
},
p_aux_2_grid_
{
p_aux_2_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
...
@@ -477,6 +475,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -477,6 +475,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
c_grid_desc_m0_
=
c_grid_desc_m0_
=
DeviceOp
::
MakeDescriptor_M0
({
MRaw
,
NRaw
},
{
I1
,
StrideC
},
grid_size
,
BlockSize
);
DeviceOp
::
MakeDescriptor_M0
({
MRaw
,
NRaw
},
{
I1
,
StrideC
},
grid_size
,
BlockSize
);
}
}
p_aux_2_grid_
=
p_workspace
+
c_grid_desc_m_n_
.
GetElementSpaceSize
();
}
}
// private:
// private:
...
@@ -812,8 +812,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -812,8 +812,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
const
BDataType
*
p_b_imag
,
const
BDataType
*
p_b_imag
,
CDataType
*
p_c_real
,
CDataType
*
p_c_real
,
CDataType
*
p_c_imag
,
CDataType
*
p_c_imag
,
CDataType
*
p_aux
,
CDataType
*
p_workspace
,
CDataType
*
p_aux_2
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
KRaw
,
...
@@ -830,8 +829,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -830,8 +829,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_b_imag
,
p_b_imag
,
p_c_real
,
p_c_real
,
p_c_imag
,
p_c_imag
,
p_aux
,
p_workspace
,
p_aux_2
,
MRaw
,
MRaw
,
NRaw
,
NRaw
,
KRaw
,
KRaw
,
...
@@ -852,8 +850,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -852,8 +850,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
const
void
*
p_b_imag
,
const
void
*
p_b_imag
,
void
*
p_c_real
,
void
*
p_c_real
,
void
*
p_c_imag
,
void
*
p_c_imag
,
void
*
p_aux
,
void
*
p_workspace
,
void
*
p_aux_2
,
index_t
MRaw
,
index_t
MRaw
,
index_t
NRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
KRaw
,
...
@@ -871,8 +868,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -871,8 +868,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static_cast
<
const
BDataType
*>
(
p_b_imag
),
static_cast
<
const
BDataType
*>
(
p_b_imag
),
static_cast
<
CDataType
*>
(
p_c_real
),
static_cast
<
CDataType
*>
(
p_c_real
),
static_cast
<
CDataType
*>
(
p_c_imag
),
static_cast
<
CDataType
*>
(
p_c_imag
),
static_cast
<
CDataType
*>
(
p_aux
),
static_cast
<
CDataType
*>
(
p_workspace
),
static_cast
<
CDataType
*>
(
p_aux_2
),
MRaw
,
MRaw
,
NRaw
,
NRaw
,
KRaw
,
KRaw
,
...
@@ -909,6 +905,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -909,6 +905,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return
str
.
str
();
return
str
.
str
();
}
}
std
::
size_t
GetWorkspaceSize
([[
maybe_unused
]]
index_t
MRaw
,
[[
maybe_unused
]]
index_t
NRaw
,
[[
maybe_unused
]]
index_t
KRaw
,
[[
maybe_unused
]]
index_t
StrideA
,
[[
maybe_unused
]]
index_t
StrideB
,
[[
maybe_unused
]]
index_t
StrideC
)
override
{
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
);
return
2
*
sizeof
(
CDataType
)
*
c_grid_desc_m_n
.
GetElementSpaceSize
();
}
};
};
}
// namespace device
}
// namespace device
...
...
test/cgemm/cgemm_util.hpp
View file @
f73c3ea2
...
@@ -72,8 +72,6 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
...
@@ -72,8 +72,6 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
const
Tensor
<
BDataType
>&
B_imag
,
const
Tensor
<
BDataType
>&
B_imag
,
Tensor
<
CDataType
>&
C_real
,
Tensor
<
CDataType
>&
C_real
,
Tensor
<
CDataType
>&
C_imag
,
Tensor
<
CDataType
>&
C_imag
,
Tensor
<
CDataType
>&
Aux
,
Tensor
<
CDataType
>&
Aux_2
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
...
@@ -84,8 +82,8 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
...
@@ -84,8 +82,8 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
DeviceMem
b_k_n_imag_device_buf
(
sizeof
(
BDataType
)
*
B_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_imag_device_buf
(
sizeof
(
BDataType
)
*
B_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_real_device_buf
(
sizeof
(
CDataType
)
*
C_real
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_real_device_buf
(
sizeof
(
CDataType
)
*
C_real
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
C_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
C_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
aux
_device_buf
(
sizeof
(
CDataType
)
*
Aux
.
mDesc
.
GetElementSpace
());
DeviceMem
workspace
_device_buf
(
cgemmPtr
->
GetWorkspaceSize
(
DeviceMem
aux_2_device_buf
(
sizeof
(
CDataType
)
*
Aux_2
.
mDesc
.
GetElementSpace
(
));
params
.
M
,
params
.
N
,
params
.
K
,
params
.
StrideA
,
params
.
StrideB
,
params
.
StrideC
));
a_m_k_real_device_buf
.
ToDevice
(
A_real
.
mData
.
data
());
a_m_k_real_device_buf
.
ToDevice
(
A_real
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
A_imag
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
A_imag
.
mData
.
data
());
...
@@ -100,8 +98,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
...
@@ -100,8 +98,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
static_cast
<
BDataType
*>
(
b_k_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
workspace_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
aux_2_device_buf
.
GetDeviceBuffer
()),
params
.
M
,
params
.
M
,
params
.
N
,
params
.
N
,
params
.
K
,
params
.
K
,
...
@@ -168,10 +165,6 @@ struct TestCGemm
...
@@ -168,10 +165,6 @@ struct TestCGemm
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_imag_device_result
(
Tensor
<
CDataType
>
c_m_n_imag_device_result
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
aux
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
aux_2
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
auto
f_generate_tensor_value
=
[](
auto
&
tensor
,
auto
type
)
{
auto
f_generate_tensor_value
=
[](
auto
&
tensor
,
auto
type
)
{
using
dataType
=
decltype
(
type
);
using
dataType
=
decltype
(
type
);
...
@@ -191,9 +184,7 @@ struct TestCGemm
...
@@ -191,9 +184,7 @@ struct TestCGemm
c_m_n_real_host_result
,
c_m_n_real_host_result
,
c_m_n_imag_host_result
,
c_m_n_imag_host_result
,
c_m_n_real_device_result
,
c_m_n_real_device_result
,
c_m_n_imag_device_result
,
c_m_n_imag_device_result
);
aux
,
aux_2
);
}
}
auto
operator
()(
DeviceCGemmPtr_
&
cgemmPtr
)
auto
operator
()(
DeviceCGemmPtr_
&
cgemmPtr
)
...
@@ -221,8 +212,6 @@ struct TestCGemm
...
@@ -221,8 +212,6 @@ struct TestCGemm
Tensor
<
CDataType
>&
c_host_imag
=
std
::
get
<
5
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_host_imag
=
std
::
get
<
5
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_device_real
=
std
::
get
<
6
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_device_real
=
std
::
get
<
6
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_device_imag
=
std
::
get
<
7
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_device_imag
=
std
::
get
<
7
>
(
host_tensors
);
Tensor
<
CDataType
>&
aux
=
std
::
get
<
8
>
(
host_tensors
);
Tensor
<
CDataType
>&
aux_2
=
std
::
get
<
9
>
(
host_tensors
);
auto
a_element_op
=
AElementwiseOperation
{};
auto
a_element_op
=
AElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
...
@@ -254,8 +243,6 @@ struct TestCGemm
...
@@ -254,8 +243,6 @@ struct TestCGemm
b_imag
,
b_imag
,
c_device_real
,
c_device_real
,
c_device_imag
,
c_device_imag
,
aux
,
aux_2
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
...
@@ -340,10 +327,6 @@ struct TestCGemmBF16
...
@@ -340,10 +327,6 @@ struct TestCGemmBF16
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
BF16
>
c_m_n_imag_device_bf16
(
Tensor
<
BF16
>
c_m_n_imag_device_bf16
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
BF16
>
aux_bf16
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
BF16
>
aux_2_bf16
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
float
>
a_m_k_real_fp32
(
Tensor
<
float
>
a_m_k_real_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
...
@@ -378,8 +361,6 @@ struct TestCGemmBF16
...
@@ -378,8 +361,6 @@ struct TestCGemmBF16
b_k_n_imag_bf16
,
b_k_n_imag_bf16
,
c_m_n_real_device_bf16
,
c_m_n_real_device_bf16
,
c_m_n_imag_device_bf16
,
c_m_n_imag_device_bf16
,
aux_bf16
,
aux_2_bf16
,
a_m_k_real_fp32
,
a_m_k_real_fp32
,
a_m_k_imag_fp32
,
a_m_k_imag_fp32
,
b_k_n_real_fp32
,
b_k_n_real_fp32
,
...
@@ -408,16 +389,14 @@ struct TestCGemmBF16
...
@@ -408,16 +389,14 @@ struct TestCGemmBF16
const
Tensor
<
BF16
>&
b_imag_bf16
=
std
::
get
<
3
>
(
host_tensors
);
const
Tensor
<
BF16
>&
b_imag_bf16
=
std
::
get
<
3
>
(
host_tensors
);
Tensor
<
BF16
>&
c_real_device_bf16
=
std
::
get
<
4
>
(
host_tensors
);
Tensor
<
BF16
>&
c_real_device_bf16
=
std
::
get
<
4
>
(
host_tensors
);
Tensor
<
BF16
>&
c_imag_device_bf16
=
std
::
get
<
5
>
(
host_tensors
);
Tensor
<
BF16
>&
c_imag_device_bf16
=
std
::
get
<
5
>
(
host_tensors
);
Tensor
<
BF16
>&
aux_bf16
=
std
::
get
<
6
>
(
host_tensors
);
Tensor
<
float
>&
a_real_fp32
=
std
::
get
<
6
>
(
host_tensors
);
Tensor
<
BF16
>&
aux_2_bf16
=
std
::
get
<
7
>
(
host_tensors
);
Tensor
<
float
>&
a_imag_fp32
=
std
::
get
<
7
>
(
host_tensors
);
Tensor
<
float
>&
a_real_fp32
=
std
::
get
<
8
>
(
host_tensors
);
Tensor
<
float
>&
b_real_fp32
=
std
::
get
<
8
>
(
host_tensors
);
Tensor
<
float
>&
a_imag_fp32
=
std
::
get
<
9
>
(
host_tensors
);
Tensor
<
float
>&
b_imag_fp32
=
std
::
get
<
9
>
(
host_tensors
);
Tensor
<
float
>&
b_real_fp32
=
std
::
get
<
10
>
(
host_tensors
);
Tensor
<
float
>&
c_real_host_fp32
=
std
::
get
<
10
>
(
host_tensors
);
Tensor
<
float
>&
b_imag_fp32
=
std
::
get
<
11
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_host_fp32
=
std
::
get
<
11
>
(
host_tensors
);
Tensor
<
float
>&
c_real_host_fp32
=
std
::
get
<
12
>
(
host_tensors
);
Tensor
<
float
>&
c_real_device_fp32
=
std
::
get
<
12
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_host_fp32
=
std
::
get
<
13
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_device_fp32
=
std
::
get
<
13
>
(
host_tensors
);
Tensor
<
float
>&
c_real_device_fp32
=
std
::
get
<
14
>
(
host_tensors
);
Tensor
<
float
>&
c_imag_device_fp32
=
std
::
get
<
15
>
(
host_tensors
);
auto
a_element_op
=
AElementwiseOperation
{};
auto
a_element_op
=
AElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
...
@@ -450,8 +429,6 @@ struct TestCGemmBF16
...
@@ -450,8 +429,6 @@ struct TestCGemmBF16
b_imag_bf16
,
b_imag_bf16
,
c_real_device_bf16
,
c_real_device_bf16
,
c_imag_device_bf16
,
c_imag_device_bf16
,
aux_bf16
,
aux_2_bf16
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment