Unverified commit 909f519c, authored by Harisankar Sadasivan, committed by GitHub

Merge branch 'develop' into universal_streamk

parents 406fa265 3bb0fe6c
@@ -107,7 +107,7 @@ __global__ void
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
-    defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
+    defined(__gfx11__) || defined(__gfx12__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -603,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
         // check device
-        if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
-             ck::is_gfx11_supported()))
+        if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
+             ck::is_gfx11_supported() || ck::is_gfx12_supported()))
         {
             return false;
         }
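These two hunks show the pattern this PR applies throughout CK to light up a new GPU family: the kernel body is fenced by a preprocessor guard so it is only emitted for supported ISAs, and the device op's support check gains the matching runtime test. A minimal sketch of the idiom, with a hypothetical kernel (only `__gfx12__` and `ck::is_gfx12_supported()` come from the diff; CK defines the `__gfx12__` family macro in ck/ck.hpp):

#include <hip/hip_runtime.h>
#include "ck/ck.hpp" // assumed to provide __gfx12__ and ck::is_gfx12_supported()

// Sketch only: demo_kernel is illustrative, not a CK kernel.
__global__ void demo_kernel(float* p)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx12__))
    p[threadIdx.x] = 0.f; // real body: host pass and supported device passes
#else
    (void)p; // stubbed out for unsupported device targets
#endif
}

bool demo_is_supported()
{
    // Host-side gate mirroring the compile-time guard above.
    return ck::is_gfx12_supported();
}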
@@ -582,7 +582,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         namespace ctc = tensor_layout::convolution;
         // check device
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
@@ -39,8 +39,9 @@ __global__ void
             const BElementwiseOperation b_element_op,
             const CDEElementwiseOperation cde_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
+    defined(__gfx12__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t block_id = get_block_1d_id();
@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
         }
-        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-           ck::is_gfx103_supported() || ck::is_gfx11_supported())
+        if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
+           ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
             {
@@ -61,7 +61,7 @@ __global__ void
             bool input_permute,
             bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
     // ***************************************************
@@ -166,6 +166,7 @@ __global__ void
     ignore = O;
     ignore = G0;
     ignore = G1;
+    ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))
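The added `ignore = alpha;` keeps the `alpha` parameter formally used in the stub branch, where the kernel body is preprocessed away and any untouched parameter would trigger -Wunused-parameter. CK's `ck::ignore` behaves like `std::ignore`; a self-contained sketch of such a sink (the exact CK definition may differ):

// A std::ignore-style sink: assigning anything to it discards the value
// but counts as a use of the assigned-from name.
struct ignore_t
{
    template <typename T>
    constexpr const ignore_t& operator=(T&&) const noexcept
    {
        return *this;
    }
};
inline constexpr ignore_t ignore{};

// Usage, as in the stub branch above:
//   ignore = alpha;
//   ignore = input_permute;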
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
@@ -60,7 +60,7 @@ __global__ void
             bool input_permute,
             bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
     // ***************************************************
@@ -165,6 +165,7 @@ __global__ void
     ignore = O;
     ignore = G0;
    ignore = G1;
+    ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))
@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         if constexpr(B0EnableLds)
         {
             // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
             constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
             constexpr auto B_KRow = I1;
+#endif
             return transform_tensor_descriptor(
                 B0BlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         if constexpr(B1EnableLds)
         {
             // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
             constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
             constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_LRow = I2;
+#else
             constexpr auto B_LRow = I1;
+#endif
             return transform_tensor_descriptor(
                 B1BlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_L1>{})),
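Both hunks make the same gfx12 adjustment: `B_KRow`/`B_LRow` become 2 instead of 1, reflecting that on gfx12 the WMMA operand tile spreads its K (here L) elements over two lane rows, and the unmerge divides the leading extent by that factor so the descriptor still covers exactly the same number of elements (the A- and B-side descriptors in the hunks below get the identical treatment). A small compile-time check of that shape arithmetic, with illustrative extents:

#include <cstddef>

constexpr std::size_t K0   = 8; // illustrative block extent
constexpr std::size_t KRow = 2; // gfx12 value from the diff; 1 elsewhere

// Unmerging K0 into (K0 / KRow, KRow) must preserve the element count.
static_assert(K0 % KRow == 0, "K0 must be divisible by KRow");
static_assert((K0 / KRow) * KRow == K0, "unmerge preserves the K0 extent");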
@@ -50,7 +50,7 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma
         if constexpr(AEnableLds)
         {
             // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
             constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
             constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto A_KRow = I2;
+#else
             constexpr auto A_KRow = I1;
+#endif
             return transform_tensor_descriptor(
                 ABlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                            make_pass_through_transform(Number<A_K1>{})),
@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma
         if constexpr(BEnableLds)
         {
             // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
             constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+            constexpr auto B_KRow = I2;
+#else
             constexpr auto B_KRow = I1;
+#endif
             return transform_tensor_descriptor(
                 BBlockDesc_{},
-                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
@@ -35,8 +35,9 @@ __global__ void
             const Block2ETileMap block_2_tile_map,
             const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+    defined(__gfx12__))
     GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                        p_in_global,
                                        out_grid_desc,
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"

namespace ck {

// Thin wrappers over the AMD sparse-MFMA ("smfmac") builtins: reg_a is the
// compressed structurally-sparse A fragment, reg_b the dense B fragment,
// reg_idx the packed sparsity indices, and reg_c the accumulator; the two
// trailing integer arguments of each builtin are modifiers, left at 0 here.

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32f16;

template <>
struct intrin_smfmac_f32_16x16x32f16<16, 16>
{
    template <class FloatC>
    __device__ static void
    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32bf16;

template <>
struct intrin_smfmac_f32_16x16x32bf16<16, 16>
{
    template <class FloatC>
    __device__ static void
    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16f16;

template <>
struct intrin_smfmac_f32_32x32x16f16<32, 32>
{
    template <class FloatC>
    __device__ static void
    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16bf16;

template <>
struct intrin_smfmac_f32_32x32x16bf16<32, 32>
{
    template <class FloatC>
    __device__ static void
    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

} // namespace ck
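A hedged usage sketch for the new wrappers. Only the wrapper itself comes from the file above; the demo function, its argument packing, and the include of CK's data-type header are illustrative assumptions (the new header's own path is not shown in this view):

#include "ck/utility/data_type.hpp" // vector_type, half4_t, half8_t
// ... plus the smfmac header introduced above.

namespace demo {

// One lane's contribution to a sparse 16x16x32 f16 tile: A is structurally
// sparse (its compressed fragment carries half of the K elements), B is
// dense, and idx packs the positions of A's stored elements.
__device__ void smfmac_tile(const ck::half4_t& a_compressed,
                            const ck::half8_t& b_dense,
                            int32_t idx,
                            ck::vector_type<float, 4>& acc)
{
    ck::intrin_smfmac_f32_16x16x32f16<16, 16>::Run(a_compressed, b_dense, idx, acc);
}

} // namespace demo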
@@ -203,7 +203,7 @@ struct vector_type<T, 1>
     }
 };
-int static err = 0;
+__device__ int static err = 0;
 template <typename T>
 struct vector_type<T, 2>
 {
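The added `__device__` qualifier matters because HIP compiles host and device code in separate passes: a namespace-scope variable without it exists only on the host side, so device code that touches `err` fails to build. A minimal standalone illustration, independent of CK:

#include <hip/hip_runtime.h>

// Lives in device global memory; visible to kernels.
__device__ static int err = 0;

__global__ void probe()
{
    err = 1; // ok: device code may reference __device__ globals;
             // without the qualifier above, this reference would not compile
}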