Commit eead0864 authored by rocking's avatar rocking
Browse files

Share E and H memory in device op

parent a1cc1504
......@@ -163,7 +163,6 @@ int main()
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{}));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<HDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
......@@ -179,7 +178,6 @@ int main()
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(HDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
......@@ -202,7 +200,6 @@ int main()
{d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
e_device_buf.GetDeviceBuffer(),
h_device_buf.GetDeviceBuffer(),
M,
N,
......@@ -250,10 +247,7 @@ int main()
N,
epsilon);
e_device_buf.FromDevice(e_m_n.mData.data());
h_device_buf.FromDevice(h_m_n.mData.data());
pass &= ck::utils::check_err(e_m_n, e_m_n_host, "Error: Incorrect results e_m_n");
pass &=
ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
}
......
......@@ -499,7 +499,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std::array<const void*, NumDTensor> p_ds_grid,
const void* p_gamma_grid,
const void* p_beta_grid,
void* p_e_grid,
void* p_h_grid,
index_t MRaw,
index_t NRaw,
......@@ -516,7 +515,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_e_grid_{static_cast<EDataType*>(p_h_grid)},
p_workspace_mean_{nullptr},
p_workspace_var_{nullptr},
p_workspace_count_{nullptr},
......@@ -938,7 +937,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std::array<const void*, NumDTensor> p_ds,
const void* p_gamma,
const void* p_beta,
void* p_e,
void* p_h,
index_t MRaw,
index_t NRaw,
......@@ -958,7 +956,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_ds,
p_gamma,
p_beta,
p_e,
p_h,
MRaw,
NRaw,
......@@ -982,7 +979,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
std::array<const void*, NumDTensor> p_ds,
const void* p_gamma,
const void* p_beta,
void* p_e,
void* p_h,
index_t MRaw,
index_t NRaw,
......@@ -1002,7 +998,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_ds,
p_gamma,
p_beta,
p_e,
p_h,
MRaw,
NRaw,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment