Unverified Commit 06b891c5 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

aggregate device macros in ck_tile config header (#1297)

parent 1274861a
...@@ -53,8 +53,7 @@ __global__ void ...@@ -53,8 +53,7 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
...@@ -45,8 +45,7 @@ __global__ void ...@@ -45,8 +45,7 @@ __global__ void
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op) const CDEElementwiseOperation cde_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t KBatch = 1; const index_t KBatch = 1;
......
...@@ -50,8 +50,7 @@ __global__ void ...@@ -50,8 +50,7 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...@@ -80,7 +79,7 @@ __global__ void ...@@ -80,7 +79,7 @@ __global__ void
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Assume B is Col-Major // Assume B is Col-Major
......
...@@ -34,8 +34,7 @@ __global__ void ...@@ -34,8 +34,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
...@@ -48,7 +47,7 @@ __global__ void ...@@ -48,7 +47,7 @@ __global__ void
karg); karg);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx9__))
} }
template <typename GridwiseGemm, template <typename GridwiseGemm,
...@@ -63,8 +62,7 @@ __global__ void ...@@ -63,8 +62,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// Pass two lds pointer is the key to tell compiler that ds_read/write // Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy // operate on different lds chunk at same time without order dependecy
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -81,7 +79,7 @@ __global__ void ...@@ -81,7 +79,7 @@ __global__ void
karg); karg);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx9__))
} }
template <typename ALayout, template <typename ALayout,
......
...@@ -33,8 +33,7 @@ __global__ void ...@@ -33,8 +33,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
...@@ -49,7 +48,7 @@ __global__ void ...@@ -49,7 +48,7 @@ __global__ void
karg.c_element_op); karg.c_element_op);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx9__))
} }
template <typename GridwiseGemm, template <typename GridwiseGemm,
...@@ -64,8 +63,7 @@ __global__ void ...@@ -64,8 +63,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// Pass two lds pointer is the key to tell compiler that ds_read/write // Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy // operate on different lds chunk at same time without order dependecy
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -84,7 +82,7 @@ __global__ void ...@@ -84,7 +82,7 @@ __global__ void
karg.c_element_op); karg.c_element_op);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx9__))
} }
template <typename ALayout, template <typename ALayout,
......
...@@ -38,8 +38,7 @@ __global__ void ...@@ -38,8 +38,7 @@ __global__ void
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op) const CElementwiseOperation c_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size]; __shared__ uint8_t p_shared[shared_size];
...@@ -52,7 +51,7 @@ __global__ void ...@@ -52,7 +51,7 @@ __global__ void
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx9__))
} }
template <index_t BlockSize, template <index_t BlockSize,
......
...@@ -3,6 +3,21 @@ ...@@ -3,6 +3,21 @@
#pragma once #pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
#define __gfx103__
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
...@@ -109,15 +124,13 @@ ...@@ -109,15 +124,13 @@
// buffer atomic add: floating point // buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #elif defined(__gfx9__) // for GPU code
defined(__gfx942__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code #else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif #endif
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
defined(__gfx942__)) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else #else
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 #define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
...@@ -137,13 +150,12 @@ ...@@ -137,13 +150,12 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ defined(__gfx9__) // for GPU code
defined(__gfx942__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code #elif defined(__gfx103__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code #elif defined(__gfx11__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif #endif
......
...@@ -55,7 +55,7 @@ struct alignas(1) float8_e4m3_t ...@@ -55,7 +55,7 @@ struct alignas(1) float8_e4m3_t
{ {
static constexpr int exponent = 4; static constexpr int exponent = 4;
static constexpr int mantissa = 3; static constexpr int mantissa = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
static constexpr int bias = 1 << (exponent - 1); // NANOO static constexpr int bias = 1 << (exponent - 1); // NANOO
#else #else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
...@@ -113,7 +113,7 @@ struct alignas(1) float8_e5m2_t ...@@ -113,7 +113,7 @@ struct alignas(1) float8_e5m2_t
{ {
static constexpr int exponent = 5; static constexpr int exponent = 5;
static constexpr int mantissa = 2; static constexpr int mantissa = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
static constexpr int bias = 1 << (exponent - 1); // NANOO static constexpr int bias = 1 << (exponent - 1); // NANOO
#else #else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
...@@ -470,7 +470,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x) ...@@ -470,7 +470,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
{ {
constexpr int seed = 42; constexpr int seed = 42;
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
float max_fp8 = 240.0f; float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union union
...@@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) ...@@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
{ {
constexpr int seed = 42; constexpr int seed = 42;
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
union union
{ {
float fval; float fval;
...@@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) ...@@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
float max_fp8 = 240.0f; float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union union
...@@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) ...@@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
} }
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
union union
{ {
float fval; float fval;
...@@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>) ...@@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
float fval; float fval;
uint32_t i32val = static_cast<uint32_t>(x); uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
...@@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) ...@@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
float fval; float fval;
uint32_t i32val = static_cast<uint32_t>(x); uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
...@@ -656,7 +656,7 @@ struct numeric_traits<fp8_t> ...@@ -656,7 +656,7 @@ struct numeric_traits<fp8_t>
{ {
static constexpr int exp = 4; static constexpr int exp = 4;
static constexpr int mant = 3; static constexpr int mant = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
static constexpr int bias = 8; static constexpr int bias = 8;
#else #else
static constexpr int bias = 7; static constexpr int bias = 7;
...@@ -668,7 +668,7 @@ struct numeric_traits<bf8_t> ...@@ -668,7 +668,7 @@ struct numeric_traits<bf8_t>
{ {
static constexpr int exp = 5; static constexpr int exp = 5;
static constexpr int mant = 2; static constexpr int mant = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
static constexpr int bias = 16; static constexpr int bias = 16;
#else #else
static constexpr int bias = 15; // IEEE static constexpr int bias = 15; // IEEE
......
...@@ -112,7 +112,7 @@ namespace impl { ...@@ -112,7 +112,7 @@ namespace impl {
template <typename OutDataType, typename InTensor> template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
// This API is designed to use the _pk_ serious of function // This API is designed to use the _pk_ serious of function
constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
......
...@@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 ...@@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
CK_TILE_DEVICE void CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #if defined(__gfx9__)
defined(__gfx942__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else #else
ck_tile::ignore = c_vec; ck_tile::ignore = c_vec;
...@@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 ...@@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
// c_vec = a_vec * b_vec // c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #if defined(__gfx9__)
defined(__gfx942__)
return bit_cast<CVecType>( return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
#else #else
...@@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 ...@@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
CK_TILE_DEVICE void CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #if defined(__gfx9__)
defined(__gfx942__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else #else
ck_tile::ignore = c_vec; ck_tile::ignore = c_vec;
...@@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 ...@@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
// c_vec = a_vec * b_vec // c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #if defined(__gfx9__)
defined(__gfx942__)
return bit_cast<CVecType>( return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else #else
...@@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 ...@@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
CK_TILE_DEVICE void CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__) #elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) { static_for<0, 2, 1>{}([&](auto k) {
...@@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 ...@@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
// c_vec = a_vec * b_vec // c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx90a__) || defined(__gfx94__)
return bit_cast<CVecType>( return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
#elif defined(__gfx908__) #elif defined(__gfx908__)
...@@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
CK_TILE_DEVICE void CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__) #elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) { static_for<0, 2, 1>{}([&](auto k) {
...@@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 ...@@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
// c_vec = a_vec * b_vec // c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx90a__) || defined(__gfx94__)
return bit_cast<CVecType>( return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#elif defined(__gfx908__) #elif defined(__gfx908__)
...@@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base ...@@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
CK_TILE_DEVICE void CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>) if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0); bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
...@@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base ...@@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
// c_vec = a_vec * b_vec // c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>) if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0)); bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment