Commit f73c3ea2 authored by myamlak's avatar myamlak
Browse files

Single workspace for cgemm + helper

parent 4379d8d1
...@@ -150,8 +150,6 @@ int main(int argc, char* argv[]) ...@@ -150,8 +150,6 @@ int main(int argc, char* argv[])
Tensor<BDataType> b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> aux(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> aux_2(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl; std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl;
std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl; std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl;
...@@ -159,8 +157,6 @@ int main(int argc, char* argv[]) ...@@ -159,8 +157,6 @@ int main(int argc, char* argv[])
std::cout << "b_k_n_imag: " << b_k_n_imag.mDesc << std::endl; std::cout << "b_k_n_imag: " << b_k_n_imag.mDesc << std::endl;
std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl; std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl;
std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl; std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl;
std::cout << "aux: " << aux.mDesc << std::endl;
std::cout << "aux_2: " << aux_2.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -178,6 +174,8 @@ int main(int argc, char* argv[]) ...@@ -178,6 +174,8 @@ int main(int argc, char* argv[])
b_k_n_imag.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n_imag.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
} }
auto cgemm = DeviceCGemmInstance{};
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpace()); DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpace());
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpace()); DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpace());
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpace()); DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpace());
...@@ -186,8 +184,7 @@ int main(int argc, char* argv[]) ...@@ -186,8 +184,7 @@ int main(int argc, char* argv[])
c_m_n_real_device_result.mDesc.GetElementSpace()); c_m_n_real_device_result.mDesc.GetElementSpace());
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) *
c_m_n_imag_device_result.mDesc.GetElementSpace()); c_m_n_imag_device_result.mDesc.GetElementSpace());
DeviceMem aux_device_buf(sizeof(CDataType) * aux.mDesc.GetElementSpace()); DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC));
DeviceMem aux_2_device_buf(sizeof(CDataType) * aux_2.mDesc.GetElementSpace());
a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data()); a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data()); a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data());
...@@ -199,7 +196,6 @@ int main(int argc, char* argv[]) ...@@ -199,7 +196,6 @@ int main(int argc, char* argv[])
auto c_element_op = PassThrough{}; auto c_element_op = PassThrough{};
// do GEMM // do GEMM
auto cgemm = DeviceCGemmInstance{};
auto invoker = cgemm.MakeInvoker(); auto invoker = cgemm.MakeInvoker();
auto argument = auto argument =
cgemm.MakeArgument(static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()), cgemm.MakeArgument(static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
...@@ -208,8 +204,7 @@ int main(int argc, char* argv[]) ...@@ -208,8 +204,7 @@ int main(int argc, char* argv[])
static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(workspace_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_2_device_buf.GetDeviceBuffer()),
M, M,
N, N,
K, K,
......
...@@ -19,8 +19,7 @@ struct DeviceCGemm : public BaseOperator ...@@ -19,8 +19,7 @@ struct DeviceCGemm : public BaseOperator
const void* p_b_imag, const void* p_b_imag,
void* p_c_real, void* p_c_real,
void* p_c_imag, void* p_c_imag,
void* p_aux, void* p_workspace,
void* p_aux_2,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
...@@ -33,6 +32,12 @@ struct DeviceCGemm : public BaseOperator ...@@ -33,6 +32,12 @@ struct DeviceCGemm : public BaseOperator
ck::index_t KBatch = 1) = 0; ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual std::size_t GetWorkspaceSize(index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC) = 0;
}; };
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
......
...@@ -427,8 +427,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -427,8 +427,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
const BDataType* p_b_grid_imag, const BDataType* p_b_grid_imag,
CDataType* p_c_grid_real, CDataType* p_c_grid_real,
CDataType* p_c_grid_imag, CDataType* p_c_grid_imag,
CDataType* p_aux_grid, CDataType* p_workspace,
CDataType* p_aux_2_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
...@@ -444,8 +443,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -444,8 +443,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_b_grid_imag_{p_b_grid_imag}, p_b_grid_imag_{p_b_grid_imag},
p_c_grid_real_{p_c_grid_real}, p_c_grid_real_{p_c_grid_real},
p_c_grid_imag_{p_c_grid_imag}, p_c_grid_imag_{p_c_grid_imag},
p_aux_grid_{p_aux_grid}, p_aux_grid_{p_workspace},
p_aux_2_grid_{p_aux_2_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
...@@ -477,6 +475,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -477,6 +475,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
c_grid_desc_m0_ = c_grid_desc_m0_ =
DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize); DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize);
} }
p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize();
} }
// private: // private:
...@@ -812,8 +812,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -812,8 +812,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
const BDataType* p_b_imag, const BDataType* p_b_imag,
CDataType* p_c_real, CDataType* p_c_real,
CDataType* p_c_imag, CDataType* p_c_imag,
CDataType* p_aux, CDataType* p_workspace,
CDataType* p_aux_2,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
...@@ -830,8 +829,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -830,8 +829,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_b_imag, p_b_imag,
p_c_real, p_c_real,
p_c_imag, p_c_imag,
p_aux, p_workspace,
p_aux_2,
MRaw, MRaw,
NRaw, NRaw,
KRaw, KRaw,
...@@ -852,8 +850,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -852,8 +850,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
const void* p_b_imag, const void* p_b_imag,
void* p_c_real, void* p_c_real,
void* p_c_imag, void* p_c_imag,
void* p_aux, void* p_workspace,
void* p_aux_2,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
...@@ -871,8 +868,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -871,8 +868,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static_cast<const BDataType*>(p_b_imag), static_cast<const BDataType*>(p_b_imag),
static_cast<CDataType*>(p_c_real), static_cast<CDataType*>(p_c_real),
static_cast<CDataType*>(p_c_imag), static_cast<CDataType*>(p_c_imag),
static_cast<CDataType*>(p_aux), static_cast<CDataType*>(p_workspace),
static_cast<CDataType*>(p_aux_2),
MRaw, MRaw,
NRaw, NRaw,
KRaw, KRaw,
...@@ -909,6 +905,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -909,6 +905,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return str.str(); return str.str();
} }
std::size_t GetWorkspaceSize([[maybe_unused]] index_t MRaw,
[[maybe_unused]] index_t NRaw,
[[maybe_unused]] index_t KRaw,
[[maybe_unused]] index_t StrideA,
[[maybe_unused]] index_t StrideB,
[[maybe_unused]] index_t StrideC) override
{
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);
return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize();
}
}; };
} // namespace device } // namespace device
......
...@@ -72,8 +72,6 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, ...@@ -72,8 +72,6 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
const Tensor<BDataType>& B_imag, const Tensor<BDataType>& B_imag,
Tensor<CDataType>& C_real, Tensor<CDataType>& C_real,
Tensor<CDataType>& C_imag, Tensor<CDataType>& C_imag,
Tensor<CDataType>& Aux,
Tensor<CDataType>& Aux_2,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
...@@ -84,8 +82,8 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, ...@@ -84,8 +82,8 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B_imag.mDesc.GetElementSpace()); DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B_imag.mDesc.GetElementSpace());
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace()); DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace());
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace()); DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace());
DeviceMem aux_device_buf(sizeof(CDataType) * Aux.mDesc.GetElementSpace()); DeviceMem workspace_device_buf(cgemmPtr->GetWorkspaceSize(
DeviceMem aux_2_device_buf(sizeof(CDataType) * Aux_2.mDesc.GetElementSpace()); params.M, params.N, params.K, params.StrideA, params.StrideB, params.StrideC));
a_m_k_real_device_buf.ToDevice(A_real.mData.data()); a_m_k_real_device_buf.ToDevice(A_real.mData.data());
a_m_k_imag_device_buf.ToDevice(A_imag.mData.data()); a_m_k_imag_device_buf.ToDevice(A_imag.mData.data());
...@@ -100,8 +98,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, ...@@ -100,8 +98,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(workspace_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(aux_2_device_buf.GetDeviceBuffer()),
params.M, params.M,
params.N, params.N,
params.K, params.K,
...@@ -168,10 +165,6 @@ struct TestCGemm ...@@ -168,10 +165,6 @@ struct TestCGemm
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> c_m_n_imag_device_result( Tensor<CDataType> c_m_n_imag_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> aux(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> aux_2(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
auto f_generate_tensor_value = [](auto& tensor, auto type) { auto f_generate_tensor_value = [](auto& tensor, auto type) {
using dataType = decltype(type); using dataType = decltype(type);
...@@ -191,9 +184,7 @@ struct TestCGemm ...@@ -191,9 +184,7 @@ struct TestCGemm
c_m_n_real_host_result, c_m_n_real_host_result,
c_m_n_imag_host_result, c_m_n_imag_host_result,
c_m_n_real_device_result, c_m_n_real_device_result,
c_m_n_imag_device_result, c_m_n_imag_device_result);
aux,
aux_2);
} }
auto operator()(DeviceCGemmPtr_& cgemmPtr) auto operator()(DeviceCGemmPtr_& cgemmPtr)
...@@ -221,8 +212,6 @@ struct TestCGemm ...@@ -221,8 +212,6 @@ struct TestCGemm
Tensor<CDataType>& c_host_imag = std::get<5>(host_tensors); Tensor<CDataType>& c_host_imag = std::get<5>(host_tensors);
Tensor<CDataType>& c_device_real = std::get<6>(host_tensors); Tensor<CDataType>& c_device_real = std::get<6>(host_tensors);
Tensor<CDataType>& c_device_imag = std::get<7>(host_tensors); Tensor<CDataType>& c_device_imag = std::get<7>(host_tensors);
Tensor<CDataType>& aux = std::get<8>(host_tensors);
Tensor<CDataType>& aux_2 = std::get<9>(host_tensors);
auto a_element_op = AElementwiseOperation{}; auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{}; auto b_element_op = BElementwiseOperation{};
...@@ -254,8 +243,6 @@ struct TestCGemm ...@@ -254,8 +243,6 @@ struct TestCGemm
b_imag, b_imag,
c_device_real, c_device_real,
c_device_imag, c_device_imag,
aux,
aux_2,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
...@@ -340,10 +327,6 @@ struct TestCGemmBF16 ...@@ -340,10 +327,6 @@ struct TestCGemmBF16
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<BF16> c_m_n_imag_device_bf16( Tensor<BF16> c_m_n_imag_device_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<BF16> aux_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<BF16> aux_2_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> a_m_k_real_fp32( Tensor<float> a_m_k_real_fp32(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
...@@ -378,8 +361,6 @@ struct TestCGemmBF16 ...@@ -378,8 +361,6 @@ struct TestCGemmBF16
b_k_n_imag_bf16, b_k_n_imag_bf16,
c_m_n_real_device_bf16, c_m_n_real_device_bf16,
c_m_n_imag_device_bf16, c_m_n_imag_device_bf16,
aux_bf16,
aux_2_bf16,
a_m_k_real_fp32, a_m_k_real_fp32,
a_m_k_imag_fp32, a_m_k_imag_fp32,
b_k_n_real_fp32, b_k_n_real_fp32,
...@@ -408,16 +389,14 @@ struct TestCGemmBF16 ...@@ -408,16 +389,14 @@ struct TestCGemmBF16
const Tensor<BF16>& b_imag_bf16 = std::get<3>(host_tensors); const Tensor<BF16>& b_imag_bf16 = std::get<3>(host_tensors);
Tensor<BF16>& c_real_device_bf16 = std::get<4>(host_tensors); Tensor<BF16>& c_real_device_bf16 = std::get<4>(host_tensors);
Tensor<BF16>& c_imag_device_bf16 = std::get<5>(host_tensors); Tensor<BF16>& c_imag_device_bf16 = std::get<5>(host_tensors);
Tensor<BF16>& aux_bf16 = std::get<6>(host_tensors); Tensor<float>& a_real_fp32 = std::get<6>(host_tensors);
Tensor<BF16>& aux_2_bf16 = std::get<7>(host_tensors); Tensor<float>& a_imag_fp32 = std::get<7>(host_tensors);
Tensor<float>& a_real_fp32 = std::get<8>(host_tensors); Tensor<float>& b_real_fp32 = std::get<8>(host_tensors);
Tensor<float>& a_imag_fp32 = std::get<9>(host_tensors); Tensor<float>& b_imag_fp32 = std::get<9>(host_tensors);
Tensor<float>& b_real_fp32 = std::get<10>(host_tensors); Tensor<float>& c_real_host_fp32 = std::get<10>(host_tensors);
Tensor<float>& b_imag_fp32 = std::get<11>(host_tensors); Tensor<float>& c_imag_host_fp32 = std::get<11>(host_tensors);
Tensor<float>& c_real_host_fp32 = std::get<12>(host_tensors); Tensor<float>& c_real_device_fp32 = std::get<12>(host_tensors);
Tensor<float>& c_imag_host_fp32 = std::get<13>(host_tensors); Tensor<float>& c_imag_device_fp32 = std::get<13>(host_tensors);
Tensor<float>& c_real_device_fp32 = std::get<14>(host_tensors);
Tensor<float>& c_imag_device_fp32 = std::get<15>(host_tensors);
auto a_element_op = AElementwiseOperation{}; auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{}; auto b_element_op = BElementwiseOperation{};
...@@ -450,8 +429,6 @@ struct TestCGemmBF16 ...@@ -450,8 +429,6 @@ struct TestCGemmBF16
b_imag_bf16, b_imag_bf16,
c_real_device_bf16, c_real_device_bf16,
c_imag_device_bf16, c_imag_device_bf16,
aux_bf16,
aux_2_bf16,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment