Commit 478df149 authored by fsx950223's avatar fsx950223

Merge remote-tracking branch 'origin/develop' into embeddings

parents 8941136f 80e05267
@@ -500,6 +500,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
         {
+#if DEBUG_LOG
             std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
                       << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
                       << ", "
@@ -520,6 +521,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                       << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
                       << std::endl;
+#endif
             if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
                                             arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
...
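The hunk above wraps the per-group descriptor dump in `#if DEBUG_LOG` so release builds compile the logging out entirely. A minimal sketch of the pattern, assuming the build system defines `DEBUG_LOG` to 0 or 1 (the helper below is illustrative, not CK code):

```cpp
// Sketch of the compile-time logging guard added above.
// Assumption: the build defines DEBUG_LOG to 0 or 1.
#ifndef DEBUG_LOG
#define DEBUG_LOG 0
#endif

#include <cstddef>
#include <iostream>

void dump_group(std::size_t group, int len0, int len1)
{
#if DEBUG_LOG
    // Present only when DEBUG_LOG is nonzero; otherwise the stream
    // calls disappear from the object code altogether.
    std::cout << "group: " << group << " {" << len0 << ", " << len1 << "}\n";
#endif
}
```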
@@ -73,8 +73,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
         static_for<0, NumReduction, 1>{}([&](auto I) {
             using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
             flag =
-                flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                            OutDataType>::value;
+                flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                             OutDataType>::value;
         });
         return flag;
...
@@ -40,8 +40,16 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceMultiBlock
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
@@ -67,8 +75,8 @@ struct DeviceReduceMultiBlock
     static constexpr bool use_multiblock =
         (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
-    static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                      OutDataType>::value,
+    static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                       OutDataType>::value,
                   "The OutDataType must support the specified OutMemoryDataOperation!");
     static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
...
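Besides widening the `DeviceReduce` base-class signature, these hunks fix the misspelled trait name (`InMemoryDataOperatonSupportedOnDataType` to `InMemoryDataOperationSupportedOnDataType`), which the `static_assert` uses to reject unsupported output types at compile time. A hypothetical sketch of how such a trait is commonly structured; CK's actual definition lives elsewhere and may differ:

```cpp
// Hypothetical sketch of a support trait like
// ck::reduce::InMemoryDataOperationSupportedOnDataType; not CK's literal code.
#include <type_traits>

enum class InMemoryDataOperationEnum { Set, AtomicAdd };

// Default: an operation/type pair is unsupported.
template <InMemoryDataOperationEnum Op, typename T>
struct InMemoryDataOperationSupportedOnDataType : std::false_type {};

// Plain stores work for every data type.
template <typename T>
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, T>
    : std::true_type {};

// Atomic adds only for types with hardware atomic support, e.g. float.
template <>
struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, float>
    : std::true_type {};

static_assert(
    InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, int>::value,
    "plain stores are always supported");
```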
@@ -35,8 +35,17 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceThreadWise
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
...
@@ -434,7 +434,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
         });
         static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+            auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
             static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                     constexpr auto offset_m_k =
...
@@ -319,7 +319,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         });
         static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+            auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
             static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                     constexpr auto offset_m_k =
...
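Both Welford-variance hunks replace a direct call to the AMDGCN builtin `__builtin_amdgcn_sqrtf` with `ck::math::sqrt`, routing through a single portable entry point. A hedged sketch of what such a wrapper can look like; CK's actual overload set may differ:

```cpp
// Hypothetical sketch of a device sqrt wrapper in the spirit of ck::math::sqrt;
// the real CK implementation may differ.
namespace ck {
namespace math {

// float overload maps to the builtin the old code called directly.
__device__ inline float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }

// double overload uses the corresponding double-precision builtin.
__device__ inline double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }

} // namespace math
} // namespace ck
```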
@@ -355,5 +355,11 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                    c3);
 }
+// Ranged input operand
+__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
+{
+    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
+}
 } // namespace ck
 #endif
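The key detail in the new `asm volatile` statement is the `"0"(c)` constraint: it ties the input accumulator to output operand `%0`, so `v_wmma_f32_16x16x16_f16` reads and writes the same register set and the compiler cannot introduce a separate copy of `c`. A self-contained x86-64 illustration of the same tied-operand mechanics (not CK code, purely to show the GNU extended-asm constraint):

```cpp
// Tied-operand constraint demo: "0"(acc) feeds the old value of acc in
// through the same register that "=r"(acc) writes, exactly like the
// accumulator operand of the v_wmma instruction above.
#include <cstdio>

int main()
{
    long acc = 5, addend = 7;
    // AT&T syntax: add %1 into %0. The compiler must keep acc in one
    // register for both the read and the write.
    asm volatile("add %1, %0" : "=r"(acc) : "r"(addend), "0"(acc));
    std::printf("%ld\n", acc); // prints 12
    return 0;
}
```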
@@ -4,11 +4,13 @@
 #ifndef CK_AMD_WMMA_HPP
 #define CK_AMD_WMMA_HPP
+#include "ck/utility/amd_inline_asm.hpp"
 #include "data_type.hpp"
 // TODO: Add arch limitation
 namespace ck {
-// wave32 only
+/********************************WAVE32 MODE***********************************************/
 // src: fp16, dst: fp32
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_wmma_f32_16x16x16_f16_w32;
@@ -19,8 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
     template <class FloatC>
     __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
-        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
-            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+        // * Inline assembly is needed to eliminate the duplicated data loads;
+        //   the compiler won't delete them for you.
+        amd_assembly_wmma_f32_16x16x16_f16_w32(
+            reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+        // reg_c.template AsType<float8_t>()(Number<0>{}) =
+        //     __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
+        //     AsType<float8_t>()[Number<0>{}]);
     }
 };
@@ -98,5 +105,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
     }
 };
+/********************************WAVE64 MODE***********************************************/
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_f16_w64;
+
+template <>
+struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
+    }
+};
+
+// src: bf16, dst: fp32
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_bf16_w64;
+
+template <>
+struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
+                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
+    }
+};
+
+// src: fp16, dst: fp16
+template <index_t MPerWave, index_t NPerWave, index_t Opsel>
+struct intrin_wmma_f16_16x16x16_f16_w64;
+
+template <index_t Opsel>
+struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
+{
+    template <class FloatC>
+    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
+    {
+        // opsel usage
+        // false: D0.[0:15]  = result
+        // true : D0.[16:31] = result
+        reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
+            reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
+    }
+};
+
+// src: bf16, dst: bf16
+template <index_t MPerWave, index_t NPerWave, index_t Opsel>
+struct intrin_wmma_bf16_16x16x16_bf16_w64;
+
+template <index_t Opsel>
+struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
+    {
+        // opsel usage
+        // false: D0.[0:15]  = result
+        // true : D0.[16:31] = result
+        reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
+                reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
+    }
+};
+
+// src: iu8, dst: i32
+template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w64;
+
+template <bool neg_a, bool neg_b, bool clamp>
+struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
+            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
+                neg_a,
+                bit_cast<int32x4_t>(reg_a),
+                neg_b,
+                bit_cast<int32x4_t>(reg_b),
+                reg_c.template AsType<int32x4_t>()[Number<0>{}],
+                clamp);
+    }
+};
 } // namespace ck
 #endif
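Note that in wave64 mode each lane carries half as many accumulator elements as in wave32 (`float4_t` versus `float8_t` for the fp32 variants), since the 16x16 tile is spread over twice as many lanes. A hypothetical call site for the new wave64 fp32 wrapper; `vector_type` and the register layout are assumptions based on the surrounding CK code:

```cpp
// Hypothetical usage sketch of the wave64 fp32 WMMA wrapper; the types and
// accumulator layout below are assumptions, not CK's literal code.
__device__ void accumulate_tile_w64(const half16_t& reg_a,
                                    const half16_t& reg_b,
                                    vector_type<float, 4>& reg_c)
{
    // Each wave64 lane holds 4 of the 16x16 fp32 accumulator elements.
    intrin_wmma_f32_16x16x16_f16_w64<16, 16>::Run(reg_a, reg_b, reg_c);
}
```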
@@ -3,7 +3,9 @@
 #pragma once
+#ifndef __HIP_DEVICE_COMPILE__
 #include <cmath>
+#endif
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/type.hpp"
...
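`__HIP_DEVICE_COMPILE__` is defined by the HIP toolchain during the device compilation pass, so the guard above keeps the host-oriented `<cmath>` header out of device code. The pattern in isolation (the helper below is illustrative, not from the diff):

```cpp
// Host-only headers are skipped during the HIP device pass, where the
// compiler defines __HIP_DEVICE_COMPILE__.
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif

// Illustrative helper (not from the diff): pick an implementation per pass.
inline float abs_portable(float x)
{
#ifndef __HIP_DEVICE_COMPILE__
    return std::fabs(x); // host pass: standard library
#else
    return __builtin_fabsf(x); // device pass: compiler builtin, no <cmath>
#endif
}
```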