Commit 7392e40c authored by Anthony Chang

make C0 precision type consistent with C

parent ac6977f7
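
The diff below applies one idea across the example, the device op, and the kernel: the C0 tensors (bias, gamma, beta) now use the same data type as C, while the intermediate layernorm math keeps its own compute/accumulation type. As a rough, standalone sketch of the host reference touched in the first file (plain std::vector instead of CK's Tensor, a hypothetical layernorm_reference name, and the assumption that ComputeDataType is a built-in floating-point type), the row-wise D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta) computation looks roughly like:

#include <cmath>
#include <cstddef>
#include <vector>

// Not the library's code: a minimal reference for one row-major MxN problem,
// with a dedicated ComputeDataType for the mean/variance accumulation so the
// precision of the intermediate math is independent of the storage types.
template <typename InDataType, typename OutDataType, typename ComputeDataType>
void layernorm_reference(std::vector<OutDataType>& result,        // MxN
                         const std::vector<ComputeDataType>& acc, // MxN
                         const std::vector<InDataType>& bias,     // 1xN
                         const std::vector<InDataType>& gamma,    // 1xN
                         const std::vector<InDataType>& beta,     // 1xN
                         std::size_t M,
                         std::size_t N,
                         ComputeDataType epsilon = 1e-5)
{
    for(std::size_t i = 0; i < M; ++i)
    {
        // accumulate sum and sum-of-squares of (acc + bias) in ComputeDataType
        ComputeDataType sum    = 0;
        ComputeDataType sum_sq = 0;
        for(std::size_t j = 0; j < N; ++j)
        {
            const ComputeDataType x = acc[i * N + j] + static_cast<ComputeDataType>(bias[j]);
            sum += x;
            sum_sq += x * x;
        }
        const ComputeDataType mean    = sum / static_cast<ComputeDataType>(N);
        const ComputeDataType var     = sum_sq / static_cast<ComputeDataType>(N) - mean * mean;
        const ComputeDataType inv_std = ComputeDataType(1) / std::sqrt(var + epsilon);

        // normalize, scale and shift; cast to the output storage type last
        for(std::size_t j = 0; j < N; ++j)
        {
            const ComputeDataType x = acc[i * N + j] + static_cast<ComputeDataType>(bias[j]);
            const ComputeDataType y =
                (x - mean) * inv_std * static_cast<ComputeDataType>(gamma[j]) +
                static_cast<ComputeDataType>(beta[j]);
            result[i * N + j] = static_cast<OutDataType>(y);
        }
    }
}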
@@ -51,9 +51,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl
// clang-format on
// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
-template <typename InDataType, typename OutDataType>
+template <typename InDataType, typename OutDataType, typename ComputeDataType>
void Layernorm(Tensor<OutDataType>& result,
-const Tensor<InDataType>& acc, // MxN
+const Tensor<ComputeDataType>& acc, // MxN
const Tensor<InDataType>& bias, // 1xN
const Tensor<InDataType>& gamma, // 1xN
const Tensor<InDataType>& beta, // 1xN
@@ -66,9 +66,9 @@ void Layernorm(Tensor<OutDataType>& result,
size_t M = acc.mDesc.GetLengths()[0];
size_t N = acc.mDesc.GetLengths()[1];
-Tensor<InDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
-Tensor<InDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
-Tensor<InDataType> acc_layernorm(acc.mDesc);
+Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
+Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
+Tensor<ComputeDataType> acc_layernorm(acc.mDesc);
// add bias
acc_layernorm.ForEach([&](auto& self, auto idx) {
@@ -78,8 +78,8 @@ void Layernorm(Tensor<OutDataType>& result,
// reduce N dim
for(size_t i = 0; i < M; i++)
{
-InDataType sum_acc_sq = 0;
-InDataType sum_acc = 0;
+ComputeDataType sum_acc_sq = 0;
+ComputeDataType sum_acc = 0;
for(size_t j = 0; j < N; j++)
{
sum_acc_sq += acc_layernorm(i, j) * acc_layernorm(i, j);
@@ -177,9 +177,9 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<AccDataType> acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-Tensor<AccDataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
-Tensor<AccDataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
-Tensor<AccDataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+Tensor<CDataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+Tensor<CDataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+Tensor<CDataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
@@ -205,18 +205,18 @@ int main(int argc, char* argv[])
}
// TODO ANT: test other init
+c0_n_bias.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
+c0_n_gamma.GenerateTensorValue(GeneratorTensor_2<CDataType>{0, 2});
+c0_n_beta.GenerateTensorValue(GeneratorTensor_2<CDataType>{0, 5});
c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0});
acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
-c0_n_bias.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-c0_n_gamma.GenerateTensorValue(GeneratorTensor_1<AccDataType>{2});
-c0_n_beta.GenerateTensorValue(GeneratorTensor_1<AccDataType>{2});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
-DeviceMem c0_bias_buf(sizeof(AccDataType) * c0_n_bias.mDesc.GetElementSpace());
-DeviceMem c0_gamma_buf(sizeof(AccDataType) * c0_n_gamma.mDesc.GetElementSpace());
-DeviceMem c0_beta_buf(sizeof(AccDataType) * c0_n_beta.mDesc.GetElementSpace());
+DeviceMem c0_bias_buf(sizeof(CDataType) * c0_n_bias.mDesc.GetElementSpace());
+DeviceMem c0_gamma_buf(sizeof(CDataType) * c0_n_gamma.mDesc.GetElementSpace());
+DeviceMem c0_beta_buf(sizeof(CDataType) * c0_n_beta.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
@@ -234,9 +234,9 @@ int main(int argc, char* argv[])
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-static_cast<AccDataType*>(c0_bias_buf.GetDeviceBuffer()),
-static_cast<AccDataType*>(c0_gamma_buf.GetDeviceBuffer()),
-static_cast<AccDataType*>(c0_beta_buf.GetDeviceBuffer()),
+static_cast<CDataType*>(c0_bias_buf.GetDeviceBuffer()),
+static_cast<CDataType*>(c0_gamma_buf.GetDeviceBuffer()),
+static_cast<CDataType*>(c0_beta_buf.GetDeviceBuffer()),
M,
N,
K,
......
@@ -423,9 +423,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
-const CShuffleDataType* p_c0_bias,
-const CShuffleDataType* p_c0_gamma,
-const CShuffleDataType* p_c0_beta,
+const CDataType* p_c0_bias,
+const CDataType* p_c0_gamma,
+const CDataType* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -470,9 +470,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
-const CShuffleDataType* p_c0_bias_;
-const CShuffleDataType* p_c0_gamma_;
-const CShuffleDataType* p_c0_beta_;
+const CDataType* p_c0_bias_;
+const CDataType* p_c0_gamma_;
+const CDataType* p_c0_beta_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
@@ -530,7 +530,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
-CShuffleDataType, // intermediate data type
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -568,7 +567,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
-CShuffleDataType, // intermediate data type
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
@@ -632,9 +630,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
-const CShuffleDataType* p_c0_bias,
-const CShuffleDataType* p_c0_gamma,
-const CShuffleDataType* p_c0_beta,
+const CDataType* p_c0_bias,
+const CDataType* p_c0_gamma,
+const CDataType* p_c0_beta,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -684,9 +682,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
-static_cast<const CShuffleDataType*>(p_c0_bias),
-static_cast<const CShuffleDataType*>(p_c0_gamma),
-static_cast<const CShuffleDataType*>(p_c0_beta),
+static_cast<const CDataType*>(p_c0_bias),
+static_cast<const CDataType*>(p_c0_gamma),
+static_cast<const CDataType*>(p_c0_beta),
MRaw,
NRaw,
KRaw,
......
@@ -18,7 +18,6 @@ namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
-typename FloatCShuffle,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
@@ -35,9 +34,9 @@ __global__ void
kernel_gemm_layernorm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, // MxN
-const FloatCShuffle* __restrict__ p_c0_bias_grid, // 1xN
-const FloatCShuffle* __restrict__ p_c0_gamma_grid, // 1xN
-const FloatCShuffle* __restrict__ p_c0_beta_grid, // 1xN
+const FloatC* __restrict__ p_c0_bias_grid, // 1xN
+const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
+const FloatC* __restrict__ p_c0_beta_grid, // 1xN
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
@@ -365,9 +364,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
-const FloatCShuffle* __restrict__ p_c0_bias_grid, // 1xN
-const FloatCShuffle* __restrict__ p_c0_gamma_grid, // 1xN
-const FloatCShuffle* __restrict__ p_c0_beta_grid, // 1xN
+const FloatC* __restrict__ p_c0_bias_grid, // 1xN
+const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
+const FloatC* __restrict__ p_c0_beta_grid, // 1xN
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
@@ -764,7 +763,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
-auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
+auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatC>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
// TODO ANT: incorporate in singly defined p_shared. calculate proper total size in
@@ -819,8 +818,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin, tensor_operation::element_wise::PassThrough{}};
auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
-FloatCShuffle,
-FloatReduceAcc,
+FloatC,
+FloatC,
decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
......
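
On the kernel side (last two hunks), the per-thread copy of C0 is now held in the C data type: c0_thread_buf is typed FloatC, and both the source and destination element types of the c0 global-to-VGPR transfer are FloatC. A rough plain-C++ illustration of that storage-versus-compute split (illustration only, not CK device code; FloatC and FloatReduceAcc below are stand-ins, not the kernel's actual fp16/fp32 definitions):

#include <cstddef>
#include <vector>

// c0 (bias/gamma/beta) stays in the C storage type, like c0_thread_buf above,
// and is presumably widened to the reduction type only where it enters the math.
using FloatC         = float;  // stand-in for the C / C0 storage type (fp16 in the example)
using FloatReduceAcc = double; // stand-in for the wider reduction/accumulation type

FloatReduceAcc add_bias(FloatReduceAcc acc, const std::vector<FloatC>& c0_thread, std::size_t j)
{
    // widen at the point of use rather than during the copy
    return acc + static_cast<FloatReduceAcc>(c0_thread[j]);
}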