"tests/vscode:/vscode.git/clone" did not exist on "aa250d0afacc6a0fac7c14126b3ba3cf30a25aa0"
Commit 8e2d0ae7 authored by rocking's avatar rocking
Browse files

verify gpu kernel with host code

parent 8b7aeb35
......@@ -129,7 +129,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<GammaDataType>& beta_n,
const Tensor<BetaDataType>& beta_n,
A_functor a_element_op,
B_functor b_element_op,
C_functor c_element_op,
......
......@@ -40,10 +40,54 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
1, // SliceM
8, // SliceK
1, // SrcVecDim (0=M, 1=K)
8, // SrcScalarPerVector
1, // SrcScalarPerVector
1, // AffineVecDim (0=M, 1=K)
1, // AffineScalarPerVector
8>; // OutScalarPerVector
1>; // OutScalarPerVector
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType>
void host_layernorm2d(const Tensor<XDataType>& x_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n,
int M,
int N,
AccDataType epislon = 1e-4)
{
Tensor<AccDataType> mean({M});
Tensor<AccDataType> var({M});
for(int m = 0; m < M; ++m)
{
mean(m) = 0;
var(m) = 0;
for(int n = 0; n < N; ++n)
{
auto x_val = ck::type_convert<AccDataType>(x_m_n(m, n));
mean(m) += x_val;
var(m) += x_val * x_val;
}
mean(m) = mean(m) / N;
var(m) = (var(m) / N) - (mean(m) * mean(m));
}
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
auto x_val = ck::type_convert<AccDataType>(x_m_n(m, n));
auto y_val = (x_val - mean(m)) / sqrt(var(m) + epislon);
y_val = (y_val * gamma_n(n)) + beta_n(n);
y_m_n(m, n) = ck::type_convert<YDataType>(y_val);
}
}
}
int main()
{
......@@ -78,6 +122,8 @@ int main()
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
beta_dev.ToDevice(beta.mData.data());
auto device_instance = DeviceInstance{};
auto argument_ptr = device_instance.MakeArgumentPointer({M, N},
......@@ -100,5 +146,14 @@ int main()
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
bool pass = true;
{
Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
host_layernorm2d<XDataType, GammaDataType, BetaDataType, YDataType, AccDataType>(
x, gamma, beta, host_y, M, N);
y_dev.FromDevice(y.mData.data());
pass &=
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
}
return (pass ? 0 : 1);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment