Commit 5ae304df authored by myamlak
Browse files

Second auxiliary buffer added

parent b3767dbe
...@@ -151,6 +151,7 @@ int main(int argc, char* argv[]) ...@@ -151,6 +151,7 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> aux(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> aux(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> aux_2(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl; std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl;
std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl; std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl;
...@@ -159,6 +160,7 @@ int main(int argc, char* argv[]) ...@@ -159,6 +160,7 @@ int main(int argc, char* argv[])
std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl; std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl;
std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl; std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl;
std::cout << "aux: " << aux.mDesc << std::endl; std::cout << "aux: " << aux.mDesc << std::endl;
std::cout << "aux_2: " << aux_2.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -185,6 +187,7 @@ int main(int argc, char* argv[]) ...@@ -185,6 +187,7 @@ int main(int argc, char* argv[])
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) *
c_m_n_imag_device_result.mDesc.GetElementSpace()); c_m_n_imag_device_result.mDesc.GetElementSpace());
DeviceMem aux_device_buf(sizeof(CDataType) * aux.mDesc.GetElementSpace()); DeviceMem aux_device_buf(sizeof(CDataType) * aux.mDesc.GetElementSpace());
DeviceMem aux_2_device_buf(sizeof(CDataType) * aux_2.mDesc.GetElementSpace());
a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data()); a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data()); a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data());
...@@ -206,6 +209,7 @@ int main(int argc, char* argv[]) ...@@ -206,6 +209,7 @@ int main(int argc, char* argv[])
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(aux_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_2_device_buf.GetDeviceBuffer()),
M, M,
N, N,
K, K,
......
...@@ -20,6 +20,7 @@ struct DeviceCGemm : public BaseOperator ...@@ -20,6 +20,7 @@ struct DeviceCGemm : public BaseOperator
void* p_c_real, void* p_c_real,
void* p_c_imag, void* p_c_imag,
void* p_aux, void* p_aux,
void* p_aux_2,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
......
...@@ -390,6 +390,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -390,6 +390,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_real, CDataType* p_c_grid_real,
CDataType* p_c_grid_imag, CDataType* p_c_grid_imag,
CDataType* p_aux_grid, CDataType* p_aux_grid,
CDataType* p_aux_2_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
...@@ -406,6 +407,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -406,6 +407,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_grid_real_{p_c_grid_real}, p_c_grid_real_{p_c_grid_real},
p_c_grid_imag_{p_c_grid_imag}, p_c_grid_imag_{p_c_grid_imag},
p_aux_grid_{p_aux_grid}, p_aux_grid_{p_aux_grid},
p_aux_2_grid_{p_aux_2_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
...@@ -434,6 +436,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -434,6 +436,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_real_; CDataType* p_c_grid_real_;
CDataType* p_c_grid_imag_; CDataType* p_c_grid_imag_;
CDataType* p_aux_grid_; CDataType* p_aux_grid_;
CDataType* p_aux_2_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
...@@ -488,7 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -488,7 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0, 0,
arg.p_a_grid_real_, arg.p_a_grid_real_,
arg.p_b_grid_real_, arg.p_b_grid_real_,
arg.p_c_grid_real_, arg.p_aux_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -505,7 +508,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -505,7 +508,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0, 0,
arg.p_a_grid_imag_, arg.p_a_grid_imag_,
arg.p_b_grid_imag_, arg.p_b_grid_imag_,
arg.p_aux_grid_, arg.p_aux_2_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -514,7 +517,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -514,7 +517,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_real = c_real - aux needed here!!! // c_real = aux - aux_2 needed here!!!
ave_time += ave_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -524,7 +527,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -524,7 +527,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0, 0,
arg.p_a_grid_real_, arg.p_a_grid_real_,
arg.p_b_grid_imag_, arg.p_b_grid_imag_,
arg.p_c_grid_imag_, arg.p_aux_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -541,7 +544,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -541,7 +544,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0, 0,
arg.p_a_grid_imag_, arg.p_a_grid_imag_,
arg.p_b_grid_real_, arg.p_b_grid_real_,
arg.p_aux_grid_, arg.p_aux_2_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -550,7 +553,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -550,7 +553,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_imag = c_imag + aux needed here!!! // c_imag = aux + aux_2 needed here!!!
} }
else else
{ {
...@@ -575,7 +578,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -575,7 +578,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0, 0,
arg.p_a_grid_real_, arg.p_a_grid_real_,
arg.p_b_grid_real_, arg.p_b_grid_real_,
arg.p_c_grid_real_, arg.p_aux_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -592,7 +595,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -592,7 +595,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0, 0,
arg.p_a_grid_imag_, arg.p_a_grid_imag_,
arg.p_b_grid_imag_, arg.p_b_grid_imag_,
arg.p_aux_grid_, arg.p_aux_2_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -601,7 +604,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -601,7 +604,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// // c_real = c_real - aux needed here!!! // // c_real = aux - aux_2 needed here!!!
ave_time += ave_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -611,7 +614,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -611,7 +614,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0, 0,
arg.p_a_grid_real_, arg.p_a_grid_real_,
arg.p_b_grid_imag_, arg.p_b_grid_imag_,
arg.p_c_grid_imag_, arg.p_aux_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -628,7 +631,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -628,7 +631,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0, 0,
arg.p_a_grid_imag_, arg.p_a_grid_imag_,
arg.p_b_grid_real_, arg.p_b_grid_real_,
arg.p_aux_grid_, arg.p_aux_2_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -637,7 +640,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -637,7 +640,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
// c_imag = c_imag + aux needed here!!! // c_imag = aux + aux_2 needed here!!!
} }
return ave_time; return ave_time;
...@@ -676,6 +679,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -676,6 +679,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_real, CDataType* p_c_real,
CDataType* p_c_imag, CDataType* p_c_imag,
CDataType* p_aux, CDataType* p_aux,
CDataType* p_aux_2,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
...@@ -693,6 +697,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -693,6 +697,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_real, p_c_real,
p_c_imag, p_c_imag,
p_aux, p_aux,
p_aux_2,
MRaw, MRaw,
NRaw, NRaw,
KRaw, KRaw,
...@@ -714,6 +719,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -714,6 +719,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
void* p_c_real, void* p_c_real,
void* p_c_imag, void* p_c_imag,
void* p_aux, void* p_aux,
void* p_aux_2,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
...@@ -732,6 +738,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -732,6 +738,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static_cast<CDataType*>(p_c_real), static_cast<CDataType*>(p_c_real),
static_cast<CDataType*>(p_c_imag), static_cast<CDataType*>(p_c_imag),
static_cast<CDataType*>(p_aux), static_cast<CDataType*>(p_aux),
static_cast<CDataType*>(p_aux_2),
MRaw, MRaw,
NRaw, NRaw,
KRaw, KRaw,
......
...@@ -73,6 +73,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, ...@@ -73,6 +73,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
Tensor<CDataType>& C_real, Tensor<CDataType>& C_real,
Tensor<CDataType>& C_imag, Tensor<CDataType>& C_imag,
Tensor<CDataType>& Aux, Tensor<CDataType>& Aux,
Tensor<CDataType>& Aux_2,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
...@@ -84,6 +85,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, ...@@ -84,6 +85,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace()); DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace());
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace()); DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace());
DeviceMem aux_device_buf(sizeof(CDataType) * Aux.mDesc.GetElementSpace()); DeviceMem aux_device_buf(sizeof(CDataType) * Aux.mDesc.GetElementSpace());
DeviceMem aux_2_device_buf(sizeof(CDataType) * Aux_2.mDesc.GetElementSpace());
a_m_k_real_device_buf.ToDevice(A_real.mData.data()); a_m_k_real_device_buf.ToDevice(A_real.mData.data());
a_m_k_imag_device_buf.ToDevice(A_imag.mData.data()); a_m_k_imag_device_buf.ToDevice(A_imag.mData.data());
...@@ -99,6 +101,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, ...@@ -99,6 +101,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(aux_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_2_device_buf.GetDeviceBuffer()),
params.M, params.M,
params.N, params.N,
params.K, params.K,
...@@ -167,6 +170,8 @@ struct TestCGemm ...@@ -167,6 +170,8 @@ struct TestCGemm
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> aux( Tensor<CDataType> aux(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> aux_2(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
auto f_generate_tensor_value = [](auto& tensor, auto type) { auto f_generate_tensor_value = [](auto& tensor, auto type) {
using dataType = decltype(type); using dataType = decltype(type);
...@@ -187,7 +192,8 @@ struct TestCGemm ...@@ -187,7 +192,8 @@ struct TestCGemm
c_m_n_imag_host_result, c_m_n_imag_host_result,
c_m_n_real_device_result, c_m_n_real_device_result,
c_m_n_imag_device_result, c_m_n_imag_device_result,
aux); aux,
aux_2);
} }
auto operator()(DeviceCGemmPtr_& cgemmPtr) auto operator()(DeviceCGemmPtr_& cgemmPtr)
...@@ -216,6 +222,7 @@ struct TestCGemm ...@@ -216,6 +222,7 @@ struct TestCGemm
Tensor<CDataType>& c_device_real = std::get<6>(host_tensors); Tensor<CDataType>& c_device_real = std::get<6>(host_tensors);
Tensor<CDataType>& c_device_imag = std::get<7>(host_tensors); Tensor<CDataType>& c_device_imag = std::get<7>(host_tensors);
Tensor<CDataType>& aux = std::get<8>(host_tensors); Tensor<CDataType>& aux = std::get<8>(host_tensors);
Tensor<CDataType>& aux_2 = std::get<9>(host_tensors);
auto a_element_op = AElementwiseOperation{}; auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{}; auto b_element_op = BElementwiseOperation{};
...@@ -248,6 +255,7 @@ struct TestCGemm ...@@ -248,6 +255,7 @@ struct TestCGemm
c_device_real, c_device_real,
c_device_imag, c_device_imag,
aux, aux,
aux_2,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
...@@ -319,6 +327,8 @@ struct TestCGemmBF16 ...@@ -319,6 +327,8 @@ struct TestCGemmBF16
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<BF16> aux_bf16( Tensor<BF16> aux_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<BF16> aux_2_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> a_m_k_real_fp32( Tensor<float> a_m_k_real_fp32(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
...@@ -354,6 +364,7 @@ struct TestCGemmBF16 ...@@ -354,6 +364,7 @@ struct TestCGemmBF16
c_m_n_real_device_bf16, c_m_n_real_device_bf16,
c_m_n_imag_device_bf16, c_m_n_imag_device_bf16,
aux_bf16, aux_bf16,
aux_2_bf16,
a_m_k_real_fp32, a_m_k_real_fp32,
a_m_k_imag_fp32, a_m_k_imag_fp32,
b_k_n_real_fp32, b_k_n_real_fp32,
...@@ -383,14 +394,15 @@ struct TestCGemmBF16 ...@@ -383,14 +394,15 @@ struct TestCGemmBF16
Tensor<BF16>& c_real_device_bf16 = std::get<4>(host_tensors); Tensor<BF16>& c_real_device_bf16 = std::get<4>(host_tensors);
Tensor<BF16>& c_imag_device_bf16 = std::get<5>(host_tensors); Tensor<BF16>& c_imag_device_bf16 = std::get<5>(host_tensors);
Tensor<BF16>& aux_bf16 = std::get<6>(host_tensors); Tensor<BF16>& aux_bf16 = std::get<6>(host_tensors);
Tensor<float>& a_real_fp32 = std::get<7>(host_tensors); Tensor<BF16>& aux_2_bf16 = std::get<7>(host_tensors);
Tensor<float>& a_imag_fp32 = std::get<8>(host_tensors); Tensor<float>& a_real_fp32 = std::get<8>(host_tensors);
Tensor<float>& b_real_fp32 = std::get<9>(host_tensors); Tensor<float>& a_imag_fp32 = std::get<9>(host_tensors);
Tensor<float>& b_imag_fp32 = std::get<10>(host_tensors); Tensor<float>& b_real_fp32 = std::get<10>(host_tensors);
Tensor<float>& c_real_host_fp32 = std::get<11>(host_tensors); Tensor<float>& b_imag_fp32 = std::get<11>(host_tensors);
Tensor<float>& c_imag_host_fp32 = std::get<12>(host_tensors); Tensor<float>& c_real_host_fp32 = std::get<12>(host_tensors);
Tensor<float>& c_real_device_fp32 = std::get<13>(host_tensors); Tensor<float>& c_imag_host_fp32 = std::get<13>(host_tensors);
Tensor<float>& c_imag_device_fp32 = std::get<14>(host_tensors); Tensor<float>& c_real_device_fp32 = std::get<14>(host_tensors);
Tensor<float>& c_imag_device_fp32 = std::get<15>(host_tensors);
auto a_element_op = AElementwiseOperation{}; auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{}; auto b_element_op = BElementwiseOperation{};
...@@ -424,6 +436,7 @@ struct TestCGemmBF16 ...@@ -424,6 +436,7 @@ struct TestCGemmBF16
c_real_device_bf16, c_real_device_bf16,
c_imag_device_bf16, c_imag_device_bf16,
aux_bf16, aux_bf16,
aux_2_bf16,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment