Commit 7392e40c authored by Anthony Chang

make C0 precision type consistent with C

parent ac6977f7
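The change separates the precision of the C0 inputs (bias, gamma, beta) from the accumulation precision: the host reference gains a ComputeDataType template parameter, and the device op reads C0 in the same type as the GEMM output C. As a minimal, self-contained sketch of what the reference computes (plain std::vector instead of CK's Tensor; the epsilon value and the name layernorm_ref are illustrative, not taken from the diff):

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch only (not CK code): row-wise LayerNorm over a GEMM accumulator.
// InDataType      : precision of bias/gamma/beta (CDataType in the example)
// OutDataType     : precision of the result      (CDataType in the example)
// ComputeDataType : precision of acc and of the intermediate math (AccDataType)
template <typename InDataType, typename OutDataType, typename ComputeDataType>
void layernorm_ref(std::vector<OutDataType>& result,        // MxN, row-major
                   const std::vector<ComputeDataType>& acc, // MxN, row-major
                   const std::vector<InDataType>& bias,     // 1xN
                   const std::vector<InDataType>& gamma,    // 1xN
                   const std::vector<InDataType>& beta,     // 1xN
                   std::size_t M,
                   std::size_t N)
{
    for(std::size_t i = 0; i < M; ++i)
    {
        // mean and mean-of-squares accumulated in ComputeDataType
        ComputeDataType sum = 0, sum_sq = 0;
        for(std::size_t j = 0; j < N; ++j)
        {
            const ComputeDataType x = acc[i * N + j] + static_cast<ComputeDataType>(bias[j]);
            sum    += x;
            sum_sq += x * x;
        }
        const ComputeDataType mean    = sum / static_cast<ComputeDataType>(N);
        const ComputeDataType var     = sum_sq / static_cast<ComputeDataType>(N) - mean * mean;
        const ComputeDataType inv_std = ComputeDataType(1) / std::sqrt(var + ComputeDataType(1e-5));

        for(std::size_t j = 0; j < N; ++j)
        {
            const ComputeDataType x = acc[i * N + j] + static_cast<ComputeDataType>(bias[j]);
            result[i * N + j] = static_cast<OutDataType>(
                (x - mean) * inv_std * static_cast<ComputeDataType>(gamma[j]) +
                static_cast<ComputeDataType>(beta[j]));
        }
    }
}

int main()
{
    const std::size_t M = 2, N = 4;
    std::vector<float> acc(M * N, 1.0f), out(M * N, 0.0f);
    std::vector<float> bias(N, 0.5f), gamma(N, 1.0f), beta(N, 0.0f);
    layernorm_ref(out, acc, bias, gamma, beta, M, N); // constant rows normalize to ~0
    return 0;
}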
@@ -51,9 +51,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl
 // clang-format on

 // D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
-template <typename InDataType, typename OutDataType>
+template <typename InDataType, typename OutDataType, typename ComputeDataType>
 void Layernorm(Tensor<OutDataType>& result,
-               const Tensor<InDataType>& acc,   // MxN
+               const Tensor<ComputeDataType>& acc, // MxN
                const Tensor<InDataType>& bias,  // 1xN
                const Tensor<InDataType>& gamma, // 1xN
                const Tensor<InDataType>& beta,  // 1xN
@@ -66,9 +66,9 @@ void Layernorm(Tensor<OutDataType>& result,
     size_t M = acc.mDesc.GetLengths()[0];
     size_t N = acc.mDesc.GetLengths()[1];

-    Tensor<InDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
-    Tensor<InDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
-    Tensor<InDataType> acc_layernorm(acc.mDesc);
+    Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
+    Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
+    Tensor<ComputeDataType> acc_layernorm(acc.mDesc);

     // add bias
     acc_layernorm.ForEach([&](auto& self, auto idx) {
@@ -78,8 +78,8 @@ void Layernorm(Tensor<OutDataType>& result,
     // reduce N dim
     for(size_t i = 0; i < M; i++)
     {
-        InDataType sum_acc_sq = 0;
-        InDataType sum_acc    = 0;
+        ComputeDataType sum_acc_sq = 0;
+        ComputeDataType sum_acc    = 0;
         for(size_t j = 0; j < N; j++)
         {
             sum_acc_sq += acc_layernorm(i, j) * acc_layernorm(i, j);
@@ -177,9 +177,9 @@ int main(int argc, char* argv[])
     Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
     Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
     Tensor<AccDataType> acc_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<AccDataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
-    Tensor<AccDataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
-    Tensor<AccDataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+    Tensor<CDataType> c0_n_bias(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+    Tensor<CDataType> c0_n_gamma(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));
+    Tensor<CDataType> c0_n_beta(HostTensorDescriptor(std::vector<size_t>({size_t(N)})));

     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
@@ -205,18 +205,18 @@ int main(int argc, char* argv[])
     }

     // TODO ANT: test other init
+    c0_n_bias.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
+    c0_n_gamma.GenerateTensorValue(GeneratorTensor_2<CDataType>{0, 2});
+    c0_n_beta.GenerateTensorValue(GeneratorTensor_2<CDataType>{0, 5});
     c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0});
     acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
-    c0_n_bias.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-    c0_n_gamma.GenerateTensorValue(GeneratorTensor_1<AccDataType>{2});
-    c0_n_beta.GenerateTensorValue(GeneratorTensor_1<AccDataType>{2});

     DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
     DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
-    DeviceMem c0_bias_buf(sizeof(AccDataType) * c0_n_bias.mDesc.GetElementSpace());
-    DeviceMem c0_gamma_buf(sizeof(AccDataType) * c0_n_gamma.mDesc.GetElementSpace());
-    DeviceMem c0_beta_buf(sizeof(AccDataType) * c0_n_beta.mDesc.GetElementSpace());
+    DeviceMem c0_bias_buf(sizeof(CDataType) * c0_n_bias.mDesc.GetElementSpace());
+    DeviceMem c0_gamma_buf(sizeof(CDataType) * c0_n_gamma.mDesc.GetElementSpace());
+    DeviceMem c0_beta_buf(sizeof(CDataType) * c0_n_beta.mDesc.GetElementSpace());

     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
@@ -234,9 +234,9 @@ int main(int argc, char* argv[])
     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                       static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                       static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                      static_cast<AccDataType*>(c0_bias_buf.GetDeviceBuffer()),
-                                      static_cast<AccDataType*>(c0_gamma_buf.GetDeviceBuffer()),
-                                      static_cast<AccDataType*>(c0_beta_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c0_bias_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c0_gamma_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c0_beta_buf.GetDeviceBuffer()),
                                       M,
                                       N,
                                       K,
...
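With the example's tensors declared above, the host verification would exercise the new template parameter roughly as in the sketch below (argument list abbreviated, call shape assumed rather than copied from the file; template arguments deduce from the tensor types):

// InDataType/OutDataType deduce to CDataType, ComputeDataType to AccDataType.
Layernorm(c_m_n_host_result,   // Tensor<CDataType>,   MxN result
          acc_m_n_host_result, // Tensor<AccDataType>, MxN GEMM accumulator
          c0_n_bias,           // Tensor<CDataType>,   1xN
          c0_n_gamma,          // Tensor<CDataType>,   1xN
          c0_n_beta            // Tensor<CDataType>,   1xN
          /* , remaining arguments unchanged */);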
@@ -423,9 +423,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
                  CDataType* p_c_grid,
-                 const CShuffleDataType* p_c0_bias,
-                 const CShuffleDataType* p_c0_gamma,
-                 const CShuffleDataType* p_c0_beta,
+                 const CDataType* p_c0_bias,
+                 const CDataType* p_c0_gamma,
+                 const CDataType* p_c0_beta,
                  index_t MRaw,
                  index_t NRaw,
                  index_t KRaw,
@@ -470,9 +470,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        const CShuffleDataType* p_c0_bias_;
-        const CShuffleDataType* p_c0_gamma_;
-        const CShuffleDataType* p_c0_beta_;
+        const CDataType* p_c0_bias_;
+        const CDataType* p_c0_gamma_;
+        const CDataType* p_c0_beta_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
@@ -530,7 +530,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
                 GridwiseGemm,
                 ADataType, // TODO: distiguish A/B datatype
                 CDataType,
-                CShuffleDataType, // intermediate data type
                 AElementwiseOperation,
                 BElementwiseOperation,
                 CElementwiseOperation,
@@ -568,7 +567,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
                 GridwiseGemm,
                 ADataType, // TODO: distiguish A/B datatype
                 CDataType,
-                CShuffleDataType, // intermediate data type
                 AElementwiseOperation,
                 BElementwiseOperation,
                 CElementwiseOperation,
@@ -632,9 +630,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
                              CDataType* p_c,
-                             const CShuffleDataType* p_c0_bias,
-                             const CShuffleDataType* p_c0_gamma,
-                             const CShuffleDataType* p_c0_beta,
+                             const CDataType* p_c0_bias,
+                             const CDataType* p_c0_gamma,
+                             const CDataType* p_c0_beta,
                              index_t MRaw,
                              index_t NRaw,
                              index_t KRaw,
@@ -684,9 +682,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
-                                          static_cast<const CShuffleDataType*>(p_c0_bias),
-                                          static_cast<const CShuffleDataType*>(p_c0_gamma),
-                                          static_cast<const CShuffleDataType*>(p_c0_beta),
+                                          static_cast<const CDataType*>(p_c0_bias),
+                                          static_cast<const CDataType*>(p_c0_gamma),
+                                          static_cast<const CDataType*>(p_c0_beta),
                                           MRaw,
                                           NRaw,
                                           KRaw,
...
@@ -18,7 +18,6 @@ namespace ck {
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename FloatCShuffle,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
@@ -35,9 +34,9 @@ __global__ void
     kernel_gemm_layernorm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
                                           const FloatAB* __restrict__ p_b_grid,
                                           FloatC* __restrict__ p_c_grid, // MxN
-                                          const FloatCShuffle* __restrict__ p_c0_bias_grid,  // 1xN
-                                          const FloatCShuffle* __restrict__ p_c0_gamma_grid, // 1xN
-                                          const FloatCShuffle* __restrict__ p_c0_beta_grid,  // 1xN
+                                          const FloatC* __restrict__ p_c0_bias_grid,  // 1xN
+                                          const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
+                                          const FloatC* __restrict__ p_c0_beta_grid,  // 1xN
                                           const AElementwiseOperation a_element_op,
                                           const BElementwiseOperation b_element_op,
                                           const CElementwiseOperation c_element_op,
@@ -365,9 +364,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
-                               const FloatCShuffle* __restrict__ p_c0_bias_grid,  // 1xN
-                               const FloatCShuffle* __restrict__ p_c0_gamma_grid, // 1xN
-                               const FloatCShuffle* __restrict__ p_c0_beta_grid,  // 1xN
+                               const FloatC* __restrict__ p_c0_bias_grid,  // 1xN
+                               const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
+                               const FloatC* __restrict__ p_c0_beta_grid,  // 1xN
                                void* __restrict__ p_shared,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
@@ -764,7 +763,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
             c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

-        auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
+        auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatC>(
             c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

         // TODO ANT: incorporate in singly defined p_shared. calculate proper total size in
@@ -819,8 +818,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin, tensor_operation::element_wise::PassThrough{}};

         auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
-            FloatCShuffle,
-            FloatReduceAcc,
+            FloatC,
+            FloatC,
             decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
             decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
             Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
...
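Summarizing the kernel-side effect of the hunks above (parameter lists abbreviated, for orientation only):

// Before: a dedicated FloatCShuffle parameter carried the C0 precision.
//   template <typename GridwiseGemm, typename FloatAB, typename FloatC,
//             typename FloatCShuffle, typename AElementwiseOperation, ...>
//   kernel_gemm_layernorm_xdl_cshuffle_v1(..., const FloatCShuffle* p_c0_bias_grid, ...);
//
// After: bias, gamma and beta are read as FloatC, i.e. the same precision as the
// C output, so the extra parameter is dropped from the kernel, from the gridwise
// Run, and from the DeviceGemmLayerNorm_Xdl_CShuffle argument/factory interfaces.
//   template <typename GridwiseGemm, typename FloatAB, typename FloatC,
//             typename AElementwiseOperation, ...>
//   kernel_gemm_layernorm_xdl_cshuffle_v1(..., const FloatC* p_c0_bias_grid, ...);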