// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP #include "ck/utility/amd_inline_asm.hpp" #include "data_type.hpp" // TODO: Add arch limitation namespace ck { /********************************WAVE32 MODE***********************************************/ // src: fp16, dst: fp32 template struct intrin_wmma_f32_16x16x16_f16_w32; template <> struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> { template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { // * Inline assembly need to elimate the duplicated data load, compiler won't help you // delete them. // amd_assembly_wmma_f32_16x16x16_f16_w32( // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; // src: bf16, dst: fp32 template struct intrin_wmma_f32_16x16x16_bf16_w32; template <> struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> { template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; // src: fp16, dst: fp16 template struct intrin_wmma_f16_16x16x16_f16_w32; template struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> { template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; // src: bf16, dst: bf16 template struct intrin_wmma_bf16_16x16x16_bf16_w32; template struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> { template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; // src: iu8, dst: i32 template struct intrin_wmma_i32_16x16x16_iu8_w32; template struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> { template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( neg_a, bit_cast(reg_a), neg_b, bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], clamp); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; /********************************WAVE64 MODE***********************************************/ template struct intrin_wmma_f32_16x16x16_f16_w64; template <> struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> { template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; // src: bf16, dst: fp32 template struct intrin_wmma_f32_16x16x16_bf16_w64; template <> struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> { template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; // src: fp16, dst: fp16 template struct intrin_wmma_f16_16x16x16_f16_w64; template struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> { template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; // src: bf16, dst: bf16 template struct intrin_wmma_bf16_16x16x16_bf16_w64; template struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> { template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; // src: iu8, dst: i32 template struct intrin_wmma_i32_16x16x16_iu8_w64; template struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> { template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( neg_a, bit_cast(reg_a), neg_b, bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], clamp); #else ignore = reg_a; ignore = reg_b; ignore = reg_c; #endif } }; } // namespace ck #endif