Commit 068fb458 authored by liuhy

Modify the CK code to adapt it to gfx926

parent acd8b8ea
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iomanip>
#include <vector>
#include <iostream>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp"

using ADataType     = ck::half_t; // Input 1
using BDataType     = ck::half_t; // Input 2
using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;
using AccDataType   = float;

using XElementwiseOperation = ck::tensor_operation::element_wise::Add;
using YElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;

constexpr int Rank         = 2;
constexpr int NumReduceDim = 1;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

int main()
{
    bool time_kernel = true;

    ck::index_t M      = 48 * 256;
    ck::index_t N      = 1024;
    ck::index_t Stride = N;

    auto mn_size = (M - 1) * Stride + N;

    SimpleDeviceMem a_dev_buf(sizeof(ADataType) * mn_size);
    SimpleDeviceMem b_dev_buf(sizeof(BDataType) * mn_size);
    SimpleDeviceMem gamma_dev_buf(sizeof(GammaDataType) * N);
    SimpleDeviceMem beta_dev_buf(sizeof(BetaDataType) * N);
    SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size);

    std::array<const void*, 2> ab_input = {a_dev_buf.GetDeviceBuffer(),
                                           b_dev_buf.GetDeviceBuffer()};

    std::vector<ck::index_t> abStride = {Stride, 1};
    std::array<std::vector<ck::index_t>, 2> abStrides = {abStride, abStride};

    using DeviceOp = ck::tensor_operation::device::DeviceElementwiseNormalization<
        ck::Tuple<ADataType, BDataType>,
        GammaDataType,
        BetaDataType,
        AccDataType,
        YDataType,
        XElementwiseOperation,
        YElementwiseOperation,
        Rank,
        NumReduceDim>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    std::string best_op_name;
    bool found            = false;
    int best_op_id        = -1;
    float best_ave_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];

        auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths
                                                        abStrides,
                                                        {0, 1},      // gammaStrides
                                                        {0, 1},      // betaStrides
                                                        {Stride, 1}, // yStrides
                                                        {1},         // reduceDims
                                                        1e-4,
                                                        ab_input,
                                                        gamma_dev_buf.GetDeviceBuffer(),
                                                        beta_dev_buf.GetDeviceBuffer(),
                                                        y_dev_buf.GetDeviceBuffer(),
                                                        XElementwiseOperation{},
                                                        YElementwiseOperation{});

        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t num_byte = sizeof(ADataType) * M * N + sizeof(BDataType) * M * N +
                                   sizeof(GammaDataType) * N + sizeof(BetaDataType) * N +
                                   sizeof(YDataType) * M * N;

            float gb_per_sec = num_byte / 1.E6 / ave_time;
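            // Note: ave_time is reported in ms and num_byte is in bytes, so
            // num_byte / 1e6 / ave_time equals bytes / (1e9 * seconds), i.e. decimal GB/s.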
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl; << op_name << std::endl;
if(ave_time < best_ave_time) if(ave_time < best_ave_time)
{ {
found = true; found = true;
best_op_id = i; best_op_id = i;
best_op_name = op_name; best_op_name = op_name;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
} }
} }
else else
{ {
std::cout << op_name << " does not support this problem" << std::endl; std::cout << op_name << " does not support this problem" << std::endl;
} }
} }
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl; << best_op_name << std::endl;
    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths
                                                        abStrides,
                                                        {1},         // gammaStrides
                                                        {1},         // betaStrides
                                                        {Stride, 1}, // yStrides
                                                        {1},         // reduceDims
                                                        1e-4,
                                                        ab_input,
                                                        gamma_dev_buf.GetDeviceBuffer(),
                                                        beta_dev_buf.GetDeviceBuffer(),
                                                        y_dev_buf.GetDeviceBuffer(),
                                                        XElementwiseOperation{},
                                                        YElementwiseOperation{});

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }

    return 0;
}
@@ -9,7 +9,10 @@ cd client_example/build
```
```bash
-cmake -D CMAKE_CXX_COMPILER=${ROCM_PATH}/bin/hipcc -D CMAKE_PREFIX_PATH="${ROCM_PATH};${PATH_TO_CK_INSTALL_DIRECTORY}" ..
+cmake \
+    -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+    -D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \
+    ..
```
### Build client example
...
@@ -27,7 +27,6 @@ message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLA
FetchContent_Declare(
    googletest
    GIT_REPOSITORY http://10.0.50.24/Mirrors/googletest.git
-   # GIT_REPOSITORY /work/home/zhangshao/installer/googletest
    GIT_TAG b85864c64758dec007208e56af933fc3f52044ee
)
...
@@ -7,7 +7,7 @@ cmake
    -D CMAKE_CXX_FLAGS="-O3" \
    -D CMAKE_BUILD_TYPE=Release \
    -D GPU_TARGETS="gfx906;gfx926" \
-   -D CMAKE_INSTALL_PREFIX=~/composable_kernel/install_ck \
+   -D CMAKE_INSTALL_PREFIX=~/composable_kernel-develop/install_ck \
    ..
cd -

add_example_executable(example_sparse_embedding3_forward_layernorm sparse_embedding3_forward_layernorm.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"

using ADataType     = ck::half_t; // Input 1
using BDataType     = ck::half_t; // Input 2
using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;
using AccDataType   = float;

using XElementwiseOperation = ck::tensor_operation::element_wise::Add;
using YElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;

constexpr int Rank         = 2;
constexpr int NumReduceDim = 1;

// X = Elementwise(input1, input2, input3, ...)
// Y = Layernorm(X, beta, gamma)
using DeviceInstance = ck::tensor_operation::device::DeviceElementwiseNormalizationImpl<
    ck::Tuple<ADataType, BDataType>,
    GammaDataType,
    BetaDataType,
    AccDataType,
    YDataType,
    XElementwiseOperation,
    YElementwiseOperation,
    Rank,
    NumReduceDim,
    256, // BlockSize
    8,   // ClusterM
    32,  // ClusterK
    1,   // SliceM
    32,  // SliceK
    1,   // SrcVecDim (0=M, 1=K)
    8,   // SrcScalarPerVector
    1,   // GammaVecDim (0=M, 1=K)
    8,   // GammaScalarPerVector
    1,   // BetaVecDim (0=M, 1=K)
    8,   // BetaScalarPerVector
    8>;  // OutScalarPerVector
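// With this configuration, a workgroup of BlockSize = 256 threads is arranged as
// ClusterM x ClusterK = 8 x 32; each block tile covers (ClusterM * SliceM) x (ClusterK * SliceK)
// = 8 x 1024 elements, which spans the full N = 1024 reduction length of the problem below.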
template <typename HostTensorA, typename HostTensorB, typename HostTensorC, typename Functor>
void host_elementwise2D(HostTensorC& C,
                        const HostTensorA& A,
                        const HostTensorB& B,
                        const std::vector<std::size_t>& shape,
                        Functor functor)
{
    using ctype = ck::remove_reference_t<decltype(C(0, 0))>;

    for(std::size_t m = 0; m < shape[0]; ++m)
        for(std::size_t n = 0; n < shape[1]; ++n)
        {
            auto a_val  = A(m, n);
            auto b_val  = B(m, n);
            ctype c_val = 0;
            functor(c_val, a_val, b_val);
            C(m, n) = c_val;
        }
}

int main()
{
    bool time_kernel = true;

    ck::index_t M      = 48 * 256;
    ck::index_t N      = 1024;
    ck::index_t Stride = N;

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                    std::vector<std::size_t>({stride}));
    };

    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                    std::vector<std::size_t>({stride, 1}));
    };

    Tensor<ADataType> a(f_host_tensor_descriptor2d(M, N, Stride));
    Tensor<BDataType> b(f_host_tensor_descriptor2d(M, N, Stride));
    Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
    Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));

    a.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
    b.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
    gamma.GenerateTensorValue(GeneratorTensor_2<GammaDataType>{-5, 5});
    beta.GenerateTensorValue(GeneratorTensor_2<BetaDataType>{-5, 5});

    DeviceMem a_dev(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
    DeviceMem b_dev(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
    DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());

    a_dev.ToDevice(a.mData.data());
    b_dev.ToDevice(b.mData.data());
    gamma_dev.ToDevice(gamma.mData.data());
    beta_dev.ToDevice(beta.mData.data());

    std::array<const void*, 2> input = {a_dev.GetDeviceBuffer(), b_dev.GetDeviceBuffer()};

    auto device_instance = DeviceInstance{};

    auto argument_ptr = device_instance.MakeArgumentPointer(
        {M, N},
        {
            std::vector<ck::index_t>{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()},
            std::vector<ck::index_t>{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()},
        },
        {0, 1},
        {0, 1},
        std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
        {1},
        1e-4,
        input,
        gamma_dev.GetDeviceBuffer(),
        beta_dev.GetDeviceBuffer(),
        y_dev.GetDeviceBuffer(),
        XElementwiseOperation{},
        YElementwiseOperation{});

    if(!device_instance.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout << "The runtime parameters are not supported" << std::endl;
        return 1;
    };

    auto invoker_ptr = device_instance.MakeInvokerPointer();

    float ela_time = 0;
    ela_time       = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    float data_mem_size = M * N * sizeof(ADataType) + M * N * sizeof(BDataType) +
                          M * N * sizeof(YDataType) + N * sizeof(GammaDataType) +
                          N * sizeof(BetaDataType);
    float bandwidth = data_mem_size * 1000 / ela_time / 1024 / 1024 / 1024;
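    // data_mem_size is in bytes and ela_time in ms; the factor of 1000 converts to bytes per
    // second and the three divisions by 1024 report binary GiB/s (the client example above
    // reports decimal GB/s instead).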
std::cout << "Bandwidth is : " << bandwidth << "GB/s . " << std::endl; std::cout << "Bandwidth is : " << bandwidth << "GB/s . " << std::endl;
std::cout << "Time elapase is : " << ela_time << " ms . " << std::endl; std::cout << "Time elapase is : " << ela_time << " ms . " << std::endl;
    bool pass = true;
    {
        std::vector<std::size_t> mn = {static_cast<unsigned long>(M),
                                       static_cast<unsigned long>(N)};

        Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
        host_elementwise2D<Tensor<ADataType>,
                           Tensor<BDataType>,
                           Tensor<XDataType>,
                           XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{});

        Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
        using ReferenceInstance =
            ck::tensor_operation::host::ReferenceLayernorm<XDataType,
                                                           GammaDataType,
                                                           BetaDataType,
                                                           YDataType,
                                                           AccDataType,
                                                           YElementwiseOperation,
                                                           Rank,
                                                           NumReduceDim>;

        ReferenceInstance ref;
        auto ref_argument =
            ref.MakeArgument(x, gamma, beta, host_y, YElementwiseOperation{}, {M, N}, {1}, 1e-4);
        auto ref_invoker = ref.MakeInvoker();
        ref_invoker.Run(ref_argument);

        y_dev.FromDevice(y.mData.data());

        pass &=
            ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);

        if(!(pass))
        {
            std::cout << "layernorm wrong" << std::endl;
        }
    }

    return (pass ? 0 : 1);
}
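For context, the reference check above applies a standard layer normalization over the reduced N dimension after the elementwise add. A minimal host-side sketch of that math (an illustrative helper, not part of this commit; it assumes row-major M x N float data, while the example uses `ck::tensor_operation::host::ReferenceLayernorm` on half-precision tensors):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Naive layernorm over the last dimension: y = gamma * (x - mean) / sqrt(var + eps) + beta.
void naive_layernorm(const std::vector<float>& x, const std::vector<float>& gamma,
                     const std::vector<float>& beta, std::vector<float>& y,
                     std::size_t M, std::size_t N, float eps = 1e-4f)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        // mean over the row
        float mean = 0.f;
        for(std::size_t n = 0; n < N; ++n)
            mean += x[m * N + n];
        mean /= N;

        // biased variance over the row
        float var = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            const float d = x[m * N + n] - mean;
            var += d * d;
        }
        var /= N;

        // normalize, then scale and shift per column
        const float inv_std = 1.f / std::sqrt(var + eps);
        for(std::size_t n = 0; n < N; ++n)
            y[m * N + n] = gamma[n] * (x[m * N + n] - mean) * inv_std + beta[n];
    }
}
```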
@@ -33,7 +33,7 @@
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
-#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)||defined(__gfx926__) || defined(__gfx908__) || \
+#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx908__) || \
    defined(__gfx90a__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code

@@ -46,7 +46,7 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32
-#elif defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__) || \
+#elif defined(__gfx906__) || defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx1030__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
...
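These hunks keep `__gfx926__` in the existing architecture guards (the edits are whitespace cleanup). Assuming the vendor toolchain follows the usual HIP/Clang convention of predefining `__gfx<arch>__` for the compilation target, a minimal sketch of how such a guard is exercised (the kernel is hypothetical; the 0x00020000 value is taken from the hunk above):

```cpp
#include <hip/hip_runtime.h>

// When device code is built for gfx926 (e.g. with an offload arch of gfx926), __gfx926__ is
// predefined, so this kernel takes the same branch the CK guards above select for gfx9-class GPUs.
__global__ void report_buffer_resource_dword(int* out)
{
#if defined(__gfx906__) || defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__)
    *out = 0x00020000; // value assigned to CK_BUFFER_RESOURCE_3RD_DWORD in the hunk above
#else
    *out = -1;
#endif
}
```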
@@ -225,13 +225,13 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
            a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
-       auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+       auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
            b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());

        constexpr auto threadwise_contraction =
            ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
                FloatA,
-               FloatB,
+               FloatA,
                FloatC,
                decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
                decltype(b_thread_desc_bk0_bn0_bn1_bk1_),

@@ -394,8 +394,8 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
            Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder

        using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
-           FloatB,
-           FloatB,
+           FloatB, // src
+           FloatA, // dst
            decltype(b_block_desc_bk0_bn0_bn1_bk1_),
            decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
            Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename InDataTypeTuple,
          typename GammaDataType,
          typename BetaDataType,
          typename AccDataType,
          typename YDataType,
          typename XElementwiseOperation,
          typename YElementwiseOperation,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceElementwiseNormalization : public BaseOperator
{
    static constexpr int NumInput = InDataTypeTuple::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<index_t> lengths,
                        const std::array<std::vector<index_t>, NumInput> inStridesArray,
                        const std::vector<index_t> gammaStrides,
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> reduceDims,
                        AccDataType epsilon,
                        const std::array<const void*, NumInput> in_dev_buffers,
                        const void* p_gamma,
                        const void* p_beta,
                        void* p_y,
                        XElementwiseOperation x_elementwise_op,
                        YElementwiseOperation y_elementwise_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename InDataTypeTuple,
          typename GammaDataType,
          typename BetaDataType,
          typename AccDataType,
          typename YDataType,
          typename XElementwiseOperation,
          typename YElementwiseOperation,
          index_t Rank,
          index_t NumReduceDim>
using DeviceElementwiseNormalizationPtr =
    std::unique_ptr<DeviceElementwiseNormalization<InDataTypeTuple,
                                                   GammaDataType,
                                                   BetaDataType,
                                                   AccDataType,
                                                   YDataType,
                                                   XElementwiseOperation,
                                                   YElementwiseOperation,
                                                   Rank,
                                                   NumReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -134,7 +134,7 @@ __global__ void
        const Block2CTileMap block_2_ctile_map,
        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)||defined(__gfx926__) || defined(__gfx1030__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx1030__))
    // offset base pointer for each work-group
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -709,7 +709,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
        namespace ctc = tensor_layout::convolution;

        // check device
-       if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
+       if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
        {
            return false;
        }
...
@@ -106,7 +106,7 @@ __global__ void
        const Block2CTileMap block_2_ctile_map,
        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx1030__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx1030__))
    // offset base pointer for each work-group
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);

@@ -600,7 +600,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
        namespace ctc = tensor_layout::convolution;

        // check device
-       if(!(ck::get_device_name() == "gfx906" ||ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
+       if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
        {
            return false;
        }
...
@@ -1391,7 +1391,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
    static bool IsSupportedArgument(const Argument& arg)
    {
        // check device
-       if(!(ck::get_device_name() == "gfx906"||ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
+       if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
        {
            return false;
        }
...
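The runtime gate in these hunks is a plain allow-list over `ck::get_device_name()`. A hypothetical stand-alone helper with the same shape of check (the function and its names are illustrative, not CK API; note that the exact list differs per operator in the hunks above, and not all of them include "gfx926"):

```cpp
#include <array>
#include <string>

// Returns true when the reported device name is on the given allow-list, mirroring the
// IsSupportedArgument() device checks shown in the hunks above.
bool device_on_allow_list(const std::string& device_name,
                          std::array<const char*, 3> allowed = {"gfx906", "gfx926", "gfx1030"})
{
    for(const char* name : allowed)
    {
        if(device_name == name)
        {
            return true;
        }
    }
    return false;
}
```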
@@ -205,6 +205,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
    using GridwiseGemm =
        GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
                                     ADataType,
+                                    BDataType,
                                     AccDataType,
                                     CDataType,
                                     InMemoryDataOperationEnum::Set,

@@ -364,6 +365,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
            const auto kernel =
                kernel_gemm_dl_v1r3<GridwiseGemm,
                                    ADataType,
+                                   BDataType,
                                    CDataType,
                                    remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                    remove_reference_t<BGridDesc_K0_N0_N1_K1>,

@@ -390,6 +392,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
            const auto kernel =
                kernel_gemm_dl_v1r3<GridwiseGemm,
                                    ADataType,
+                                   BDataType,
                                    CDataType,
                                    remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                    remove_reference_t<BGridDesc_K0_N0_N1_K1>,

@@ -416,6 +419,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
            const auto kernel =
                kernel_gemm_dl_v1r3<GridwiseGemm,
                                    ADataType,
+                                   BDataType,
                                    CDataType,
                                    remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                    remove_reference_t<BGridDesc_K0_N0_N1_K1>,

@@ -442,6 +446,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
            const auto kernel =
                kernel_gemm_dl_v1r3<GridwiseGemm,
                                    ADataType,
+                                   BDataType,
                                    CDataType,
                                    remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                    remove_reference_t<BGridDesc_K0_N0_N1_K1>,

@@ -483,7 +488,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
    static bool IsSupportedArgument(const Argument& arg)
    {
-       if(ck::get_device_name() == "gfx906" ||ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030")
+       if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030")
        {
            return GridwiseGemm::CheckValidity(
                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

// X = Elementwise(input1, input2, input3, ...)
// Y = Normalization(X, beta, gamma)
template <typename InDataTypePointerTuple,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename XElementwiseOperation,
          typename YElementwiseOperation,
          typename InGrid2dDescTuple,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t GammaSrcVectorDim,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorDim,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
          bool SweepOnce>
struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr index_t NumInput = InDataTypePointerTuple::Size();

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto XThreadBufferNumber     = Number<KThreadSliceSize / XSrcVectorSize>{};
    static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / GammaSrcVectorSize>{};
    static constexpr auto BetaThreadBufferNumber  = Number<KThreadSliceSize / BetaSrcVectorSize>{};
    static constexpr auto YThreadBufferNumber     = Number<KThreadSliceSize / YDstVectorSize>{};
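    // Each thread's KThreadSliceSize elements are moved as KThreadSliceSize / *SrcVectorSize
    // vector transfers, which is what the *ThreadBufferNumber counts above express
    // (4 buffers of 8 elements for the 32 / 8 configuration used in the example earlier in
    // this commit).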
    __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
                                        int thread_k_cluster_id)
    {
        int kPerBlock  = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
        int kPerThread =
            kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
        int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;

        if(kPerBlockTail > 0)
        {
            static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
                int thread_max_len =
                    (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
                int delta = thread_max_len - kPerBlockTail;
                delta     = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
                kPerThread += XSrcVectorSize - delta;
            });
        }

        return kPerThread;
    }

    __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& y_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               AccDataType epsilon,
                               const InDataTypePointerTuple p_in_global_tuple,
                               XDataType* const __restrict__ p_x_lds,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               const XElementwiseOperation x_elementwise_op,
                               const YElementwiseOperation y_elementwise_op)
    {
        if constexpr(SweepOnce)
        {
            num_k_block_tile_iteration = 1;
        }

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t grid_size       = get_grid_size();

        auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() ==
                              2); // matrix dimension

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        auto x_lds_val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto) {
                return generate_tuple(
                    [&](auto) {
                        return StaticBuffer<AddressSpaceEnum::Vgpr,
                                            AccDataType,
                                            MThreadSliceSize * XSrcVectorSize,
                                            true>{};
                    },
                    Number<NumInput>{});
            },
            Number<XThreadBufferNumber>{});

        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * XSrcVectorSize,
                                    true>{};
            },
            Number<XThreadBufferNumber>{});

        auto gamma_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * GammaSrcVectorSize,
                                    true>{};
            },
            Number<GammaThreadBufferNumber>{});

        auto beta_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * BetaSrcVectorSize,
                                    true>{};
            },
            Number<BetaThreadBufferNumber>{});

        auto y_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * YDstVectorSize,
                                    true>{};
            },
            Number<YThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K         = Sequence<MThreadSliceSize, XSrcVectorSize>;
        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

        auto in_global_load_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return ThreadwiseTensorSliceTransfer_v2<DataType,
                                                        AccDataType,
                                                        decltype(in_grid_2d_desc_tuple[I]),
                                                        decltype(thread_buffer_desc_m_k),
                                                        ThreadBufferLengths_M_K,
                                                        ThreadBufferDimAccessOrder,
                                                        XSrcVectorDim,
                                                        XSrcVectorSize,
                                                        1,
                                                        false>{
                    in_grid_2d_desc_tuple[I],
                    make_multi_index(block_global_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * XSrcVectorSize)};
            },
            Number<NumInput>{});
        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  GridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XSrcVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * XSrcVectorSize));

        auto threadwise_gamma_load =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             AccDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             GammaSrcVectorDim,
                                             GammaSrcVectorSize,
                                             1,
                                             true>(
                gamma_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * GammaSrcVectorSize));

        auto threadwise_beta_load =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                             AccDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             BetaSrcVectorDim,
                                             BetaSrcVectorSize,
                                             1,
                                             true>(
                beta_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * BetaSrcVectorSize));

        using PassThrough = tensor_operation::element_wise::PassThrough;
        PassThrough pass_through_op;

        auto threadwise_x_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               XDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               GridDesc_M_K,
                                               PassThrough,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               XSrcVectorDim,
                                               XSrcVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                x_grid_desc_m_k,
                make_multi_index(thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * XSrcVectorSize),
                pass_through_op);

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               YDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               GridDesc_M_K,
                                               YElementwiseOperation,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               YDstVectorDim,
                                               YDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                y_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * YDstVectorSize),
                y_elementwise_op);

        // Copy x from Cache
        // one pass: fwd, second pass: bwd
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
        constexpr auto thread_copy_bwd_step_m_k =
            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, NumInput, 1>{}([&](auto I) { // input load loop
                    in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
                                                in_global_buf_tuple[I],
                                                thread_buffer_desc_m_k,
                                                make_tuple(I0, I0),
                                                in_thread_buf_tuple(iK0)(I));

                    in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
                                                               thread_copy_fwd_step_m_k);
                });

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // get reference to in data
                        const auto in_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto I) -> const auto& {
                                return in_thread_buf_tuple(iK0)(I)(Number<offset_m_k>{});
                            },
                            Number<NumInput>{});

                        // get reference to dst data
                        auto out_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto) -> auto& { return x_thread_buf(iK0)(Number<offset_m_k>{}); },
                            I1);

                        unpack2(x_elementwise_op, out_data_refs, in_data_refs);
                    });
                });

                threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf);

                if constexpr(!SweepOnce)
                {
                    threadwise_x_store.Run(thread_buffer_desc_m_k,
                                           make_tuple(I0, I0),
                                           x_thread_buf(iK0),
                                           x_grid_desc_m_k,
                                           x_lds_val_buf);
                    threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k,
thread_copy_fwd_step_m_k); thread_copy_fwd_step_m_k);
} }
}); });
} }
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0) if constexpr(I > 0)
block_sync_lds(); block_sync_lds();
int count = threadwise_welford.cur_count_; int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
}); });
auto thread_copy_tail_m_k = auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k; (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;
if constexpr(!SweepOnce) if constexpr(!SweepOnce)
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{ {
if constexpr(!SweepOnce) if constexpr(!SweepOnce)
{ {
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
x_lds_val_buf, x_lds_val_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf(i)); x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
}); });
} }
static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) { static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k, threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf, gamma_global_val_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
gamma_thread_buf(i)); gamma_thread_buf(i));
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k); thread_copy_fwd_step_m_k);
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon); auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; divisor;
// gamma // gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) * y_thread_buf(iK0)(Number<offset_m_k>{}) *
gamma_thread_buf(iK0)(Number<offset_m_k>{}); gamma_thread_buf(iK0)(Number<offset_m_k>{});
}); });
}); });
}); });
static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) { static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k, threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf, beta_global_val_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
beta_thread_buf(i)); beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k); thread_copy_fwd_step_m_k);
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
// beta // beta
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) + y_thread_buf(iK0)(Number<offset_m_k>{}) +
beta_thread_buf(iK0)(Number<offset_m_k>{}); beta_thread_buf(iK0)(Number<offset_m_k>{});
}); });
}); });
}); });
static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
threadwise_y_store.Run(thread_buffer_desc_m_k, threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
y_thread_buf(i), y_thread_buf(i),
y_grid_desc_m_k, y_grid_desc_m_k,
y_global_val_buf); y_global_val_buf);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
}); });
if constexpr(!SweepOnce) if constexpr(!SweepOnce)
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k); 2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k); 2 * thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
} }
} }
}; };
} // namespace ck } // namespace ck
...@@ -117,18 +117,11 @@ struct GridwiseGemmDlMultipleD_km_kn_mn ...@@ -117,18 +117,11 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool ret=(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
K1 == a_grid_desc_k0_m_k1.GetLength(I2) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
if(!ret){
std::cout<<"M="<<M<<" N="<<N<<" K0="<<K0<<" c_grid_desc_m_n[0]="<<c_grid_desc_m_n.GetLength(I0)<<" c_grid_desc_m_n[1]="<<c_grid_desc_m_n.GetLength(I1)
<<" b_grid_desc_k0_n_k1[0]="<<b_grid_desc_k0_n_k1.GetLength(I0)<<" b_grid_desc_k0_n_k1[2]="<<b_grid_desc_k0_n_k1.GetLength(I2)
<<" a_grid_desc_k0_m_k1[2]="<<a_grid_desc_k0_m_k1.GetLength(I2)
<<" K1="<<K1<<" MPerBlock="<<MPerBlock<<" NPerBlock="<<NPerBlock<<" K0PerBlock="<<K0PerBlock<<std::endl;
}
return ret;
} }
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatC, typename FloatC,
typename AGridDesc_K0_M0_M1_K1, typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1, typename BGridDesc_K0_N0_N1_K1,
...@@ -30,23 +31,27 @@ __global__ void ...@@ -30,23 +31,27 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, kernel_gemm_dl_v1r3(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size_of_a =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByteToA() / sizeof(FloatA);
constexpr index_t shared_block_size_of_b =
GridwiseGemm::GetSharedMemoryNumberOfByteToB() / sizeof(FloatB);
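    // note (added annotation): A and B tiles now get separate LDS allocations, each sized by its own element type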
__shared__ FloatAB p_shared_block[shared_block_size]; __shared__ FloatA p_shared_block_a[shared_block_size_of_a];
__shared__ FloatB p_shared_block_b[shared_block_size_of_b];
GridwiseGemm::Run(p_a_grid, GridwiseGemm::Run(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared_block_a,
p_shared_block_b,
a_grid_desc_k0_m0_m1_k1, a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1, b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11, c_grid_desc_m0_m10_m11_n0_n10_n11,
...@@ -56,7 +61,8 @@ __global__ void ...@@ -56,7 +61,8 @@ __global__ void
} }
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
...@@ -99,7 +105,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -99,7 +105,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByteToA()
{ {
// TODO: change this. I think it needs multi-dimensional alignment // TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
...@@ -122,7 +128,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -122,7 +128,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
constexpr auto b_block_aligned_space_size = constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); return 2 * a_block_aligned_space_size * sizeof(FloatA);
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByteToB()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
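    // note (added annotation): only the B-side size is returned below; the A-side descriptor above
    // mirrors GetSharedMemoryNumberOfByteToA and does not affect the result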
return 2 * b_block_aligned_space_size * sizeof(FloatB);
} }
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
...@@ -145,14 +177,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -145,14 +177,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{ {
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); // = M0 * N0
return grid_size; return grid_size;
} }
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{ {
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; // true iff K0 > 2 * K0PerBlock, i.e. there are K tiles beyond the prologue and tail
return has_main_k_block_loop; return has_main_k_block_loop;
} }
...@@ -170,7 +202,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -170,7 +202,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto M1 = Number<MPerBlock>{}; const auto M1 = Number<MPerBlock>{}; // 128
const auto M0 = M / M1; const auto M0 = M / M1;
const auto a_grid_desc_k0_m0_m1_k1 = const auto a_grid_desc_k0_m0_m1_k1 =
...@@ -178,8 +210,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -178,8 +210,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)), make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), // K0, M, K1
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); // K0, M0, M1, K1
return a_grid_desc_k0_m0_m1_k1; return a_grid_desc_k0_m0_m1_k1;
} }
...@@ -190,7 +222,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -190,7 +222,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto N1 = Number<NPerBlock>{}; const auto N1 = Number<NPerBlock>{}; // 128
const auto N0 = N / N1; const auto N0 = N / N1;
const auto b_grid_desc_k0_n0_n1_k1 = const auto b_grid_desc_k0_n0_n1_k1 =
...@@ -198,8 +230,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -198,8 +230,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)), make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), // K0, N, K1
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); // K0, N0, N1, K1
return b_grid_desc_k0_n0_n1_k1; return b_grid_desc_k0_n0_n1_k1;
} }
...@@ -210,33 +242,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -210,33 +242,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{}; constexpr auto M1 = Number<MPerBlock>{}; // 128
constexpr auto N1 = Number<NPerBlock>{}; constexpr auto N1 = Number<NPerBlock>{}; // 128
const auto M0 = M / M1; const auto M0 = M / M1;
const auto N0 = N / N1; const auto N0 = N / N1;
constexpr auto M11 = constexpr auto M11 = // 64
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) * Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) * // S<8, 2> ==> 8*2=16
M1PerThreadM111>{}; M1PerThreadM111>{}; // M1PerThread 4
constexpr auto N11 = constexpr auto N11 = // 64
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) * Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) * // 16
N1PerThreadN111>{}; N1PerThreadN111>{}; // N1PerThread 4
constexpr auto M10 = M1 / M11; constexpr auto M10 = M1 / M11; // 2
constexpr auto N10 = N1 / N11; constexpr auto N10 = N1 / N11; // 2
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))), make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}), // M, N
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); // M0, M10, M11, N0, N10, N11
return c_grid_desc_m0_m10_m11_n0_n10_n11; return c_grid_desc_m0_m10_m11_n0_n10_n11;
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping // TODO: document exactly what mapping this generates
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
...@@ -252,10 +284,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -252,10 +284,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void __device__ static void
Run(const FloatAB* __restrict__ p_a_grid, Run(const FloatA* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, FloatA* __restrict__ p_shared_block_a,
FloatB* __restrict__ p_shared_block_b,
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
...@@ -304,12 +337,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -304,12 +337,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM // A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); // (16, 128, 2), 2
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM // B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); // (16, 128, 2), 2
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() && a_k0_m_k1_block_desc.GetElementSpaceSize() &&
...@@ -325,8 +358,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -325,8 +358,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatA,
FloatAB, FloatA,
remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>, remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
decltype(a_block_desc_k0_m0_m1_k1), decltype(a_block_desc_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -349,8 +382,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -349,8 +382,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatB,
FloatAB, FloatB,
remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>, remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
decltype(b_block_desc_k0_n0_n1_k1), decltype(b_block_desc_k0_n0_n1_k1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -368,14 +401,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -368,14 +401,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS // a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS // b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize, BlockSize,
FloatAB, FloatA, // A and B element types are now passed separately
FloatAB, FloatB,
FloatAcc, FloatAcc,
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_block_desc),
...@@ -400,8 +433,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -400,8 +433,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
constexpr auto b_block_aligned_space_size = math::integer_least_multiple( constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block; FloatA* p_a_block_double = p_shared_block_a;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatB* p_b_block_double = p_shared_block_b;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
...@@ -436,7 +469,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -436,7 +469,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); // K / K1(=2)
index_t k_block_data_begin = 0; index_t k_block_data_begin = 0;
...@@ -487,7 +520,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -487,7 +520,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
k_block_data_begin += 2 * K0PerBlock; k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock); } while(k_block_data_begin < K0 - 2 * K0PerBlock); // K0PerBlock = 16
} }
// LDS double buffer: tail // LDS double buffer: tail
......
...@@ -151,7 +151,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1 ...@@ -151,7 +151,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);
dst_buf(Number<dst_offset>{}) = type_convert<DstData>( dst_buf(Number<dst_offset>{}) = type_convert<DstData>(
src_vector.template AsType<DstData>()[Number<src_vector_offset>{}]); src_vector.template AsType<SrcData>()[Number<src_vector_offset>{}]);
}); });
}); });
} }
......
...@@ -942,6 +942,10 @@ using int8x16_t = typename vector_type<int8_t, 16>::type; ...@@ -942,6 +942,10 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type; using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
// u8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type;
// Convert X to Y // Convert X to Y
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
...@@ -951,6 +955,13 @@ __host__ __device__ constexpr Y type_convert(X x) ...@@ -951,6 +955,13 @@ __host__ __device__ constexpr Y type_convert(X x)
return static_cast<Y>(x); return static_cast<Y>(x);
} }
// convert uint8 to fp16
template <>
__host__ __device__ constexpr half_t type_convert<half_t, uint8_t>(uint8_t x)
{
return static_cast<half_t>(x);
}
// convert bfp16 to fp32 // convert bfp16 to fp32
template <> template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
......
...@@ -9,6 +9,31 @@ namespace ck { ...@@ -9,6 +9,31 @@ namespace ck {
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC>
__device__ void inner_product(const TA& a, const TB& b, TC& c); __device__ void inner_product(const TA& a, const TB& b, TC& c);
template <>
__device__ void inner_product<half2_t, uint8x2_t, float>(const half2_t& a, const uint8x2_t& b, float& c)
{
const vector_type<half_t, 2> a_vector{a};
const vector_type<uint8_t, 2> b_vector{b};
    vector_type<half_t, 2> b_fp16_vector; // written by the inline asm below, so it must not be const
static constexpr uint32_t mask_for_elt_01 = 0x05020500;
static constexpr uint32_t mask_for_elt_23 = 0x05030501;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
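    // u8 -> f16 via byte permute: a source byte placed below the exponent byte 0x64 gives the
    // f16 bit pattern of 1024 + x; the v_sub_f16x2 below subtracts the 0x6480 (= 1152.0) magic
    // to remove that bias (an implied zero-point of 128)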
asm volatile("v_perm_b32 %0,%1,%2,%3;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[0]) : "v"(start_byte_for_fp16), "v"(((uint32_t*)&b_vector.data_)[0]), "v"(mask_for_elt_01));
// asm volatile("v_perm_b32 %0,%1,%2,%3;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[1]) : "v"(start_byte_for_fp16), "v"(((uint32_t*)&b_vector.data_)[0]), "v"(mask_for_elt_23));
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x6480;
asm volatile("v_sub_f16x2 %0, %1, %2;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[0]) : "v"(((uint32_t*)&b_fp16_vector.data_)[0]), "v"(I8s_TO_F16s_MAGIC_NUM));
// asm volatile("v_sub_f16x2 %0, %1, %2;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[1]) : "v"(((uint32_t*)&b_fp16_vector.data_)[1]), "v"(I8s_TO_F16s_MAGIC_NUM));
static_for<0, 2, 1>{}([&](auto i) {
        // accumulate in fp32; converting the fp16 operands through int32_t would truncate them
        c += type_convert<float>(a_vector.AsType<half_t>()[i]) *
             type_convert<float>(b_fp16_vector.AsType<half_t>()[i]);
});
}
template <> template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c) __device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{ {
...@@ -71,12 +96,6 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f ...@@ -71,12 +96,6 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
c); c);
} }
template <>
__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
{
c+=static_cast<float>(a*b);
}
template <> template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c) __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{ {
......
...@@ -15,6 +15,7 @@ namespace device { ...@@ -15,6 +15,7 @@ namespace device {
namespace instance { namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using I8 = int8_t;
using F32 = float; using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
...@@ -34,13 +35,13 @@ using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< ...@@ -34,13 +35,13 @@ using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple<
// #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
// #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> DeviceGemmDl< F16, I8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; >;
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Col, Row, Row, F16, I8, F16, PassThrough, PassThrough, PassThrough>>>&
instances) instances)
{ {
add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{}); add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{});
......