Commit 5ae304df authored by myamlak's avatar myamlak
Browse files

Second auxiliary buffer added

parent b3767dbe
......@@ -151,6 +151,7 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> aux(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> aux_2(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl;
std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl;
......@@ -159,6 +160,7 @@ int main(int argc, char* argv[])
std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl;
std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl;
std::cout << "aux: " << aux.mDesc << std::endl;
std::cout << "aux_2: " << aux_2.mDesc << std::endl;
switch(init_method)
{
......@@ -185,6 +187,7 @@ int main(int argc, char* argv[])
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) *
c_m_n_imag_device_result.mDesc.GetElementSpace());
DeviceMem aux_device_buf(sizeof(CDataType) * aux.mDesc.GetElementSpace());
DeviceMem aux_2_device_buf(sizeof(CDataType) * aux_2.mDesc.GetElementSpace());
a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data());
......@@ -206,6 +209,7 @@ int main(int argc, char* argv[])
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_2_device_buf.GetDeviceBuffer()),
M,
N,
K,
......
......@@ -20,6 +20,7 @@ struct DeviceCGemm : public BaseOperator
void* p_c_real,
void* p_c_imag,
void* p_aux,
void* p_aux_2,
ck::index_t M,
ck::index_t N,
ck::index_t K,
......
......@@ -390,6 +390,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_real,
CDataType* p_c_grid_imag,
CDataType* p_aux_grid,
CDataType* p_aux_2_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -406,6 +407,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_grid_real_{p_c_grid_real},
p_c_grid_imag_{p_c_grid_imag},
p_aux_grid_{p_aux_grid},
p_aux_2_grid_{p_aux_2_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
......@@ -434,6 +436,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_real_;
CDataType* p_c_grid_imag_;
CDataType* p_aux_grid_;
CDataType* p_aux_2_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
......@@ -488,7 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0,
arg.p_a_grid_real_,
arg.p_b_grid_real_,
arg.p_c_grid_real_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -505,7 +508,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0,
arg.p_a_grid_imag_,
arg.p_b_grid_imag_,
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -514,7 +517,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_real = c_real - aux needed here!!!
// c_real = aux - aux_2 needed here!!!
ave_time +=
launch_and_time_kernel(stream_config,
......@@ -524,7 +527,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0,
arg.p_a_grid_real_,
arg.p_b_grid_imag_,
arg.p_c_grid_imag_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -541,7 +544,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0,
arg.p_a_grid_imag_,
arg.p_b_grid_real_,
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -550,7 +553,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_imag = c_imag + aux needed here!!!
// c_imag = aux + aux_2 needed here!!!
}
else
{
......@@ -575,7 +578,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0,
arg.p_a_grid_real_,
arg.p_b_grid_real_,
arg.p_c_grid_real_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -592,7 +595,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0,
arg.p_a_grid_imag_,
arg.p_b_grid_imag_,
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -601,7 +604,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// // c_real = c_real - aux needed here!!!
// // c_real = aux - aux_2 needed here!!!
ave_time +=
launch_and_time_kernel(stream_config,
......@@ -611,7 +614,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0,
arg.p_a_grid_real_,
arg.p_b_grid_imag_,
arg.p_c_grid_imag_,
arg.p_aux_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -628,7 +631,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
0,
arg.p_a_grid_imag_,
arg.p_b_grid_real_,
arg.p_aux_grid_,
arg.p_aux_2_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -637,7 +640,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
// c_imag = c_imag + aux needed here!!!
// c_imag = aux + aux_2 needed here!!!
}
return ave_time;
......@@ -676,6 +679,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_real,
CDataType* p_c_imag,
CDataType* p_aux,
CDataType* p_aux_2,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -693,6 +697,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_real,
p_c_imag,
p_aux,
p_aux_2,
MRaw,
NRaw,
KRaw,
......@@ -714,6 +719,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
void* p_c_real,
void* p_c_imag,
void* p_aux,
void* p_aux_2,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -732,6 +738,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static_cast<CDataType*>(p_c_real),
static_cast<CDataType*>(p_c_imag),
static_cast<CDataType*>(p_aux),
static_cast<CDataType*>(p_aux_2),
MRaw,
NRaw,
KRaw,
......
......@@ -73,6 +73,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
Tensor<CDataType>& C_real,
Tensor<CDataType>& C_imag,
Tensor<CDataType>& Aux,
Tensor<CDataType>& Aux_2,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
......@@ -84,6 +85,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace());
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace());
DeviceMem aux_device_buf(sizeof(CDataType) * Aux.mDesc.GetElementSpace());
DeviceMem aux_2_device_buf(sizeof(CDataType) * Aux_2.mDesc.GetElementSpace());
a_m_k_real_device_buf.ToDevice(A_real.mData.data());
a_m_k_imag_device_buf.ToDevice(A_imag.mData.data());
......@@ -99,6 +101,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_2_device_buf.GetDeviceBuffer()),
params.M,
params.N,
params.K,
......@@ -167,6 +170,8 @@ struct TestCGemm
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> aux(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> aux_2(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
auto f_generate_tensor_value = [](auto& tensor, auto type) {
using dataType = decltype(type);
......@@ -187,7 +192,8 @@ struct TestCGemm
c_m_n_imag_host_result,
c_m_n_real_device_result,
c_m_n_imag_device_result,
aux);
aux,
aux_2);
}
auto operator()(DeviceCGemmPtr_& cgemmPtr)
......@@ -216,6 +222,7 @@ struct TestCGemm
Tensor<CDataType>& c_device_real = std::get<6>(host_tensors);
Tensor<CDataType>& c_device_imag = std::get<7>(host_tensors);
Tensor<CDataType>& aux = std::get<8>(host_tensors);
Tensor<CDataType>& aux_2 = std::get<9>(host_tensors);
auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{};
......@@ -248,6 +255,7 @@ struct TestCGemm
c_device_real,
c_device_imag,
aux,
aux_2,
a_element_op,
b_element_op,
c_element_op);
......@@ -319,6 +327,8 @@ struct TestCGemmBF16
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<BF16> aux_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<BF16> aux_2_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> a_m_k_real_fp32(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
......@@ -354,6 +364,7 @@ struct TestCGemmBF16
c_m_n_real_device_bf16,
c_m_n_imag_device_bf16,
aux_bf16,
aux_2_bf16,
a_m_k_real_fp32,
a_m_k_imag_fp32,
b_k_n_real_fp32,
......@@ -383,14 +394,15 @@ struct TestCGemmBF16
Tensor<BF16>& c_real_device_bf16 = std::get<4>(host_tensors);
Tensor<BF16>& c_imag_device_bf16 = std::get<5>(host_tensors);
Tensor<BF16>& aux_bf16 = std::get<6>(host_tensors);
Tensor<float>& a_real_fp32 = std::get<7>(host_tensors);
Tensor<float>& a_imag_fp32 = std::get<8>(host_tensors);
Tensor<float>& b_real_fp32 = std::get<9>(host_tensors);
Tensor<float>& b_imag_fp32 = std::get<10>(host_tensors);
Tensor<float>& c_real_host_fp32 = std::get<11>(host_tensors);
Tensor<float>& c_imag_host_fp32 = std::get<12>(host_tensors);
Tensor<float>& c_real_device_fp32 = std::get<13>(host_tensors);
Tensor<float>& c_imag_device_fp32 = std::get<14>(host_tensors);
Tensor<BF16>& aux_2_bf16 = std::get<7>(host_tensors);
Tensor<float>& a_real_fp32 = std::get<8>(host_tensors);
Tensor<float>& a_imag_fp32 = std::get<9>(host_tensors);
Tensor<float>& b_real_fp32 = std::get<10>(host_tensors);
Tensor<float>& b_imag_fp32 = std::get<11>(host_tensors);
Tensor<float>& c_real_host_fp32 = std::get<12>(host_tensors);
Tensor<float>& c_imag_host_fp32 = std::get<13>(host_tensors);
Tensor<float>& c_real_device_fp32 = std::get<14>(host_tensors);
Tensor<float>& c_imag_device_fp32 = std::get<15>(host_tensors);
auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{};
......@@ -424,6 +436,7 @@ struct TestCGemmBF16
c_real_device_bf16,
c_imag_device_bf16,
aux_bf16,
aux_2_bf16,
a_element_op,
b_element_op,
c_element_op);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment