"...composable_kernel-1.git" did not exist on "89e1ebd4d5b1bd21fe4ad58fba37cc9f5e17f4a6"
Commit e6715976 authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'develop' into dl_conv_multiple_d

parents ca313a29 10c72ace
...@@ -618,9 +618,9 @@ pipeline { ...@@ -618,9 +618,9 @@ pipeline {
stage('Clang Format') { stage('Clang Format') {
agent{ label rocmnode("nogpu") } agent{ label rocmnode("nogpu") }
environment{ environment{
execute_cmd = "find .. -iname \'*.h\' \ execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \
-o -iname \'*.hpp\' \ -o -not -path \'*.git*\' -iname \'*.hpp\' \
-o -iname \'*.cpp\' \ -o -not -path \'*.git*\' -iname \'*.cpp\' \
-o -iname \'*.h.in\' \ -o -iname \'*.h.in\' \
-o -iname \'*.hpp.in\' \ -o -iname \'*.hpp.in\' \
-o -iname \'*.cpp.in\' \ -o -iname \'*.cpp.in\' \
......
add_executable(client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp)
target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <numeric>
#include <iomanip>
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"
using XDataType = float;
using YDataType = float;
using AccDataType = float;
using ScaleDataType = AccDataType;
using BiasDataType = AccDataType;
using MeanVarDataType = AccDataType;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 4;
constexpr int NumBatchNormReduceDim = 3;
const double epsilon = std::numeric_limits<float>::epsilon();
const double averageFactor = 0.1;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
// In the actual application, the instance index and name are usually from the perf db
static int instance_index = -1;
static std::string instance_name;
int main(int argc, char* argv[])
{
std::array<ck::index_t, Rank> xyLengths{16, 8, 128, 256};
std::array<ck::index_t, Rank> xyStrides{8 * 128 * 256, 128 * 256, 256, 1};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarLengths{256};
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarStrides{1};
std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
ck::index_t numXYElement =
std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies<ck::index_t>());
ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(),
scaleBiasMeanVarLengths.end(),
1,
std::multiplies<ck::index_t>());
SimpleDeviceMem x(sizeof(XDataType) * numXYElement);
SimpleDeviceMem y(sizeof(YDataType) * numXYElement);
SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem bias(sizeof(BiasDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
using DeviceOp = ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
PassThrough,
Rank,
NumBatchNormReduceDim>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
bool found = false;
int best_op_index = -1;
float best_ave_time = std::numeric_limits<float>::max();
// profile device operation instances and save the best performant instance index and instance
// name
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
bias.GetDeviceBuffer(),
epsilon,
PassThrough{},
y.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
averageFactor,
nullptr,
nullptr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
if(ave_time < best_ave_time)
{
found = true;
best_op_index = i;
best_ave_time = ave_time;
}
}
}
if(found)
{
instance_index = best_op_index;
instance_name = op_ptrs[instance_index]->GetTypeIdHashCode();
};
// simulate the execution of the operation when the instance index and name are available
const auto op_ptrs_2 = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(instance_index >= 0 && instance_index < op_ptrs_2.size())
{
auto& op_ptr = op_ptrs_2[instance_index];
if(op_ptr->GetTypeIdHashCode() == instance_name)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
xyStrides,
xyStrides,
reduceDims,
scaleBiasMeanVarLengths,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
scaleBiasMeanVarStrides,
x.GetDeviceBuffer(),
scale.GetDeviceBuffer(),
bias.GetDeviceBuffer(),
epsilon,
PassThrough{},
y.GetDeviceBuffer(),
mean.GetDeviceBuffer(),
invVariance.GetDeviceBuffer(),
averageFactor,
nullptr,
nullptr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float exec_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
size_t num_bytes = numXYElement * (sizeof(XDataType) + sizeof(YDataType)) +
numScaleBiasMeanVarElement *
(sizeof(ScaleDataType) + sizeof(BiasDataType) +
sizeof(MeanVarDataType) + sizeof(MeanVarDataType));
float gb_per_sec = num_bytes / 1.E6 / exec_time;
std::cout << "Kernel execution time: " << std::setw(10) << exec_time
<< " ms, effective data transfer bandwidth: " << gb_per_sec << " GB/s"
<< std::endl;
}
};
}
return 0;
}
## CK docker hub
[Docker hub](https://hub.docker.com/r/rocm/composable_kernel)
## Why do I need this?
To make our lives easier and bring Composable Kernel dependencies together, we recommend using docker images.
## So what is Composable Kernel?
Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++.
To get the CK library
```
git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git
```
run a docker container
```
docker run \
-it \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
rocm/composable_kernel:ck_ub20.04_rocm5.3_release \
/bin/bash
```
and build the CK
```
mkdir build && cd build
# Need to specify target ID, example below is for gfx908 and gfx90a
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-O3" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx908;gfx90a" \
..
```
and
```
make -j examples tests
```
To run all the test cases including tests and examples run
```
make test
```
We can also run specific examples or tests like
```
./bin/example_gemm_xdl_fp16
./bin/test_gemm_fp16
```
For more details visit [CK github repo](https://github.com/ROCmSoftwarePlatform/composable_kernel), [CK examples](https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/example), [even more CK examples](https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/client_example).
## And what is inside?
The docker images have everything you need for running CK including:
* [ROCm](https://www.amd.com/en/graphics/servers-solutions-rocm)
* [CMake](https://cmake.org/)
* [Compiler](https://github.com/RadeonOpenCompute/llvm-project)
## Which image is right for me?
Let's take a look at the image naming, for example "ck_ub20.04_rocm5.4_release". The image specs are:
* "ck" - made for running Composable Kernel
* "ub20.04" - based on Ubuntu 20.04
* "rocm5.4" - ROCm platform version 5.4
* "release" - compiler version is release
So just pick the right image for your project dependencies and you're all set.
## DIY starts here
If you need to customize a docker image or just can't stop tinkering, feel free to adjust the [Dockerfile](https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/Dockerfile) for your needs.
## License
CK is released under the MIT [license](https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/LICENSE).
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_2d.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t;
using ADataType = F16;
using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ADataType>,
ck::Tuple<BDataType>,
PassThrough,
3, // NumDim_M
1, // NumDim_N
8,
8,
ck::Sequence<8>,
ck::Sequence<8>>;
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
const std::vector<std::size_t>& shape_nchw,
Functor functor)
{
for(std::size_t n = 0; n < shape_nchw[0]; ++n)
for(std::size_t c = 0; c < shape_nchw[1]; ++c)
for(std::size_t h = 0; h < shape_nchw[2]; ++h)
for(std::size_t w = 0; w < shape_nchw[3]; ++w)
{
auto a_val = A_nchw(n, c, h, w);
functor(B_nhwc(n, h, w, c), a_val);
}
}
int main()
{
bool do_verification = true;
bool time_kernel = true;
const int N = 120;
const int C = 128;
const int H = 32;
const int W = 1024;
/**const int N = 120;
const int H = 32;
const int W = 64;
const int C = 128;**/
std::vector<std::size_t> nchw = {N, C, H, W};
std::vector<std::size_t> nhwc = {N, H, W, C};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
// LogRangeAsType<float>(std::cout << "Tensor a : ", a.mData, ",") << std::endl;
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths{N, H, W, C};
std::array<ck::index_t, 4> a_strides = {C * H * W, W, 1, H * W};
std::array<ck::index_t, 4> b_strides = {H * W * C, W * C, C, 1};
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
std::cout << "A (nchw): " << a.mDesc << std::endl;
std::cout << "B (nhwc): " << b.mDesc << std::endl;
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer();
float ave_time =
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3];
std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) +
sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
bool pass = true;
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
// LogRangeAsType<float>(std::cout << "Tensor b : ", b.mData, ",") << std::endl;
Tensor<BDataType> host_b(nhwc);
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
host_b, a, nchw, PassThrough{});
// LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
return pass ? 0 : 1;
}
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
// check GPU target // check GPU target
#ifdef __HIP_DEVICE_COMPILE__ #ifdef __HIP_DEVICE_COMPILE__
#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ #if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx1030__)) defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__))
#error Not supported target #error Not supported target
#endif #endif
#endif #endif
...@@ -38,6 +38,8 @@ ...@@ -38,6 +38,8 @@
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code #elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000
#endif #endif
// FMA instruction // FMA instruction
...@@ -62,6 +64,13 @@ ...@@ -62,6 +64,13 @@
#define CK_USE_AMD_MFMA_BF16_1K_OP #define CK_USE_AMD_MFMA_BF16_1K_OP
#endif #endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx1100__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
// buffer load // buffer load
#define CK_USE_AMD_BUFFER_LOAD 1 #define CK_USE_AMD_BUFFER_LOAD 1
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <cmath> #include <cmath>
#include <string> #include <string>
#include <sstream>
#include "ck/stream_config.hpp" #include "ck/stream_config.hpp"
...@@ -46,6 +47,17 @@ struct BaseOperator ...@@ -46,6 +47,17 @@ struct BaseOperator
virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; } virtual std::string GetTypeString() const { return ""; }
virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
virtual std::string GetTypeIdHashCode() const
{
std::ostringstream oss;
oss << std::hex << typeid(*this).hash_code();
return oss.str();
};
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim_m,
index_t NumDim_n,
index_t MPerThread,
index_t NPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n>
{
static constexpr index_t NumDim = NumDim_m + NumDim_n;
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
template <typename Desc_MN>
static auto PadDescriptor_MN_2d(Desc_MN desc_mn,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
std::ignore = blockSize;
std::ignore = gridSize;
const auto m = desc_mn.GetLength(I0);
const auto n = desc_mn.GetLength(I1);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
const auto desc_mn_pad = transform_tensor_descriptor(
desc_mn,
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return desc_mn_pad;
}
static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& stride,
index_t gridSize,
index_t blockSize,
index_t num_threads_m,
index_t num_threads_n)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
// merge nd to 2d desc - [s0 * s1 * ...]
if constexpr(NumDim > 2)
{
const auto desc_mn = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n);
}
else
return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n);
}
template <index_t TupleSize>
static auto GenerateInOutGrid2dDescTuple(Number<TupleSize>)
{
return generate_tuple(
[&](auto) {
if constexpr(NumDim > 2)
{
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1);
}
else
{
return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1);
};
},
Number<TupleSize>{});
};
using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple,
OutGrid2dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
MPerThread,
NPerThread,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
blockSize_(256),
gridSize_(120), // FIXME - Calculate the grid size by number of CU in the future
num_threads_m_((gridSize_ * blockSize_) / 16),
num_threads_n_(16)
{
static_assert(NumDim_m > 0, "");
static_assert(NumDim_n > 0, "");
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
in_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(lengths,
inStridesArray[I.value],
gridSize_,
blockSize_,
num_threads_m_,
num_threads_n_);
},
Number<NumInput>{});
out_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeDescriptor_MN(lengths,
outStridesArray[I.value],
gridSize_,
blockSize_,
num_threads_m_,
num_threads_n_);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
InGrid2dDescTuple in_grid_2d_desc_tuple_;
OutGrid2dDescTuple out_grid_2d_desc_tuple_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
index_t blockSize_;
index_t gridSize_;
index_t num_threads_m_;
index_t num_threads_n_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_elementwise_2d<GridwiseElementwise,
InGrid2dDescTuple,
OutGrid2dDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.in_grid_2d_desc_tuple_,
arg.out_grid_2d_desc_tuple_,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_,
arg.num_threads_m_,
arg.num_threads_n_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t vectorDim) {
if(strides[vectorDim] == 1 &&
(lengths[vectorDim] % scalarPerVector == 0 ||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
{
return true;
}
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
{
return true;
}
return false;
};
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->inStridesArray_[I.value],
InScalarPerVectorSeq::At(I),
NumDim_m - 1))
valid = false;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(pArg->lengths_,
pArg->outStridesArray_[I.value],
OutScalarPerVectorSeq::At(I),
NumDim - 1))
valid = false;
});
return valid;
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
}; // namespace device
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -373,7 +373,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -373,7 +373,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
N01_{N01}, N01_{N01},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op},
kraw_{K}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_, b_grid_desc_k0_n_k1_,
...@@ -401,6 +402,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -401,6 +402,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t kraw_;
}; };
// Invoker // Invoker
...@@ -410,6 +412,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -410,6 +412,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...@@ -422,6 +425,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -422,6 +425,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
...@@ -528,6 +532,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -528,6 +532,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(arg.kraw_ % K1 != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -549,6 +549,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -549,6 +549,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{ {
#if 0
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_container_{" std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
...@@ -581,6 +582,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -581,6 +582,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
<< " ) " << std::endl; << " ) " << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
......
...@@ -265,7 +265,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -265,7 +265,8 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
N01_{N01}, N01_{N01},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op},
kraw_{K}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
...@@ -299,6 +300,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -299,6 +300,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t kraw_;
}; };
// Invoker // Invoker
...@@ -443,6 +445,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -443,6 +445,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
return false; return false;
} }
if(arg.kraw_ % K1 != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -422,7 +422,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -422,7 +422,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op},
kraw_{KRaw}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
...@@ -448,6 +449,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -448,6 +449,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t kraw_;
}; };
// Invoker // Invoker
...@@ -578,6 +580,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -578,6 +580,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return false; return false;
} }
if((arg.kraw_ % AK1 != 0 || arg.kraw_ % BK1 != 0) &&
!(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
......
...@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
} }
}); });
} }
else
{
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
// SPDX-License-Identifier: MIT
// // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
//
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise2dFunctor,
typename InGrid2dDescTuple,
typename OutGrid2dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation>
__global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tuple,
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
{
GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple,
out_grid_2d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
num_threads_m,
num_threads_n);
}
template <typename InGrid2dDescTuple,
typename OutGrid2dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
index_t MPerThread,
index_t NPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_2D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid2dDescTuple::Size() &&
NumOutput == OutGrid2dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto thread_buffer_desc_mn =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}, Number<NPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
{
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread,
true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread,
true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_2d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t thread_1d_id = get_thread_global_1d_id();
index_t tid_m = thread_1d_id / num_threads_n;
index_t tid_n = thread_1d_id % num_threads_n;
const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
decltype(in_grid_2d_desc_tuple[I]),
decltype(thread_buffer_desc_mn),
Sequence<MPerThread, NPerThread>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVectorSeq::At(I), // ScalarPerVector
1, // SrcScalarStrideInVector
true>{in_grid_2d_desc_tuple[I], thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<
DataType,
DataType,
decltype(thread_buffer_desc_mn),
decltype(out_grid_2d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread, NPerThread>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
true>(out_grid_2d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter_m = M / (loop_step_m);
do
{
index_t num_iter_n = N / (loop_step_n);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_mn,
make_tuple(I0, I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
make_multi_index(0, loop_step_n));
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
static_for<0, NPerThread, 1>{}([&](auto iN) {
constexpr auto offset =
thread_buffer_desc_mn.CalculateOffset(make_tuple(iM, iN));
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& {
return in_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumInput>{});
// get referenec to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& {
return out_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mn,
make_tuple(I0, I0),
out_thread_buf_tuple[I],
out_grid_2d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_2d_desc_tuple[I],
make_multi_index(0, loop_step_n));
});
} while(--num_iter_n);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_2d_desc_tuple[I],
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_2d_desc_tuple[I],
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
});
} while(--num_iter_m);
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP
#include "data_type.hpp"
// TODO: Add arch limitation
namespace ck {
// wave32 only
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;
template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
};
// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32;
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
}
};
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32;
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
}
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
neg_a,
bit_cast<int32x4_t>(reg_a),
neg_b,
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp);
}
};
} // namespace ck
#endif
...@@ -114,7 +114,16 @@ static inline __device__ int4_t abs(int4_t x) ...@@ -114,7 +114,16 @@ static inline __device__ int4_t abs(int4_t x)
}; };
#endif #endif
static inline __device__ half_t abs(half_t x) { return ::__habs(x); }; static inline __device__ half_t abs(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x;
};
static inline __device__ bool isnan(float x) { return ::isnan(x); }; static inline __device__ bool isnan(float x) { return ::isnan(x); };
...@@ -140,7 +149,12 @@ static inline __device__ bool isnan(int4_t x) ...@@ -140,7 +149,12 @@ static inline __device__ bool isnan(int4_t x)
}; };
#endif #endif
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); }; static inline __device__ bool isnan(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
static inline __device__ float sqrt(float x) { return ::sqrtf(x); }; static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
......
...@@ -25,7 +25,8 @@ using S = ck::Sequence<Is...>; ...@@ -25,7 +25,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances = using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances =
...@@ -71,12 +72,36 @@ using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances = ...@@ -71,12 +72,36 @@ using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances =
// clang-format on // clang-format on
>; >;
// irregular tile size
using device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | |
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | |
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
// clang-format on
>;
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances) instances)
{ {
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances{});
} }
} // namespace instance } // namespace instance
......
...@@ -25,7 +25,8 @@ using S = ck::Sequence<Is...>; ...@@ -25,7 +25,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n] // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances = using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances =
...@@ -71,12 +72,36 @@ using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances = ...@@ -71,12 +72,36 @@ using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances =
// clang-format on // clang-format on
>; >;
// irregular tile size
using device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | |
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | |
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
// clang-format on
>;
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances) instances)
{ {
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances{});
} }
} // namespace instance } // namespace instance
......
...@@ -25,7 +25,8 @@ using S = ck::Sequence<Is...>; ...@@ -25,7 +25,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
...@@ -98,12 +99,36 @@ using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = ...@@ -98,12 +99,36 @@ using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
// clang-format on // clang-format on
>; >;
// irregular tile size
using device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | |
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | |
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
// clang-format on
>;
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances) instances)
{ {
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
} }
} // namespace instance } // namespace instance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment