Commit 5215f11d authored by rocking

Write out the e tensor for debugging.

This could be removed later, with h used instead.
parent 3b97076d
@@ -48,7 +48,6 @@ using BLayout = Col;
using D0Layout = Row;
using D1Layout = Row;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
using HLayout = Row;
using AElementOp = PassThrough;
@@ -67,6 +66,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
@@ -88,6 +95,8 @@ auto f_host_tensor_descriptor2d =
int main()
{
bool do_verification = true;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
@@ -107,6 +116,7 @@ int main()
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{}));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<HDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
@@ -122,6 +132,7 @@ int main()
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(HDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
@@ -144,6 +155,7 @@ int main()
{d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
e_device_buf.GetDeviceBuffer(),
h_device_buf.GetDeviceBuffer(),
M,
N,
@@ -164,4 +176,29 @@ int main()
}
invoker.Run(argument, StreamConfig{nullptr, false});
if(do_verification)
{
Tensor<AccDataType> c_m_n_host(HostTensorDescriptor{M, N});
Tensor<HDataType> e_m_n_host(HostTensorDescriptor{M, N});
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host(m, n), c_m_n_host(m, n), d0_n(n), d1_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n.mData.data());
return ck::utils::check_err(e_m_n, e_m_n_host) ? 0 : 1;
}
}
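
The commit note above says the extra e output is only for debugging and could be removed, with h checked instead. A hedged sketch of what that alternative verification might look like, replacing the final e comparison inside the do_verification block: the epsilon value, the h_m_n_host name, and the two-pass mean/variance are assumptions for illustration (the device path uses Welford updates), and the normalization follows a standard layernorm over each row of length N with gamma_n and beta_n.

// Hypothetical h verification (assumes <cmath> is included and that
// e_m_n_host was computed with cde_element_op as above).
Tensor<HDataType> h_m_n_host(HostTensorDescriptor{M, N});
const AccDataType epsilon = 1e-4; // assumed; should match the device op's epsilon

for(int m = 0; m < M; ++m)
{
    AccDataType mean = 0, var = 0;
    for(int n = 0; n < N; ++n)
        mean += ck::type_convert<AccDataType>(e_m_n_host(m, n));
    mean /= N;
    for(int n = 0; n < N; ++n)
    {
        const AccDataType diff = ck::type_convert<AccDataType>(e_m_n_host(m, n)) - mean;
        var += diff * diff;
    }
    var /= N; // biased variance, as in layernorm
    for(int n = 0; n < N; ++n)
    {
        const AccDataType x = ck::type_convert<AccDataType>(e_m_n_host(m, n));
        h_m_n_host(m, n) = ck::type_convert<HDataType>(
            (x - mean) / std::sqrt(var + epsilon) *
                ck::type_convert<AccDataType>(gamma_n(n)) +
            ck::type_convert<AccDataType>(beta_n(n)));
    }
}

h_device_buf.FromDevice(h_m_n.mData.data());
return ck::utils::check_err(h_m_n, h_m_n_host) ? 0 : 1;
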
@@ -463,6 +463,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std::array<const void*, NumDTensor> p_ds_grid,
const void* p_gamma_grid,
const void* p_beta_grid,
void* p_e_grid,
void* p_h_grid,
index_t MRaw,
index_t NRaw,
@@ -479,7 +480,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{nullptr},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_welford_mean_grid_{nullptr},
p_welford_var_grid_{nullptr},
p_welford_count_grid_{nullptr},
@@ -509,9 +510,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
// TODO - hipFree
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
// TODO - GetWorkSpaceSize(), let user hipMalloc the memory
int gemm_welford_size = MRaw * gemm_nblock_;
hip_check_error(
hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
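
The TODO above asks for the remaining internal Welford scratch hipMalloc calls to be replaced by a user-provided workspace. A hypothetical caller-side sketch of that pattern follows: GetWorkSpaceSize is named in the TODO itself, while SetWorkSpacePointer and the exact wiring follow CK's BaseOperator convention and are assumptions, not code from this commit.

// Hypothetical workspace wiring; 'device_op' and 'argument' refer to the
// instances already built in the example above (names assumed where elided).
std::size_t workspace_size = device_op.GetWorkSpaceSize(&argument);

void* p_workspace = nullptr;
hip_check_error(hipMalloc(&p_workspace, workspace_size));
device_op.SetWorkSpacePointer(&argument, p_workspace);

invoker.Run(argument, StreamConfig{nullptr, false});

// Scratch is owned by the caller, so it can be freed deterministically,
// which also addresses the "TODO - hipFree" above.
hip_check_error(hipFree(p_workspace));
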
@@ -770,6 +769,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std::array<const void*, NumDTensor> p_ds,
const void* p_gamma,
const void* p_beta,
void* p_e,
void* p_h,
index_t MRaw,
index_t NRaw,
@@ -789,6 +789,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_ds,
p_gamma,
p_beta,
p_e,
p_h,
MRaw,
NRaw,
@@ -812,6 +813,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std::array<const void*, NumDTensor> p_ds,
const void* p_gamma,
const void* p_beta,
void* p_e,
void* p_h,
index_t MRaw,
index_t NRaw,
@@ -831,6 +833,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_ds,
p_gamma,
p_beta,
p_e,
p_h,
MRaw,
NRaw,