Commit b30d416c authored by Jun Liu

Merge branch 'develop' into amd-develop

parents 2fd6c6d4 94fbaac0
@@ -55,7 +55,7 @@ __global__ void
     const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -49,8 +49,7 @@ __global__ void
     const CElementwiseOperation c_element_op,
     const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -75,7 +74,7 @@ __global__ void
     ignore = b_element_op;
     ignore = c_element_op;
     ignore = block_2_ctile_map;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }
 template <index_t BlockSize,
...
@@ -25,7 +25,7 @@ __global__ void
 kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(
@@ -50,7 +50,7 @@ __global__ void
     typename GridwiseGemm::Problem problem)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
...
@@ -26,7 +26,7 @@ __global__ void
 kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
     __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -54,7 +54,7 @@ __global__ void
     typename GridwiseGemm::Problem problem)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...
@@ -58,7 +58,7 @@ __global__ void
     const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     // TODO ANT: separate into MMA + Epilogue
...
@@ -167,7 +167,7 @@ __global__ void
     const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -45,7 +45,7 @@ __global__ void
     const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
...
@@ -38,7 +38,7 @@ __global__ void
     typename GridwiseGemm::Block2CTileMap block_mapping)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
...
@@ -39,7 +39,7 @@ __global__ void
     const CGridDesc_M_N c_grid_desc_m_n)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
@@ -70,7 +70,7 @@ __global__ void
 kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const auto a_grid_desc_k0_m_k1 =
...
@@ -43,7 +43,7 @@ __global__ void
     const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...
@@ -37,7 +37,7 @@ __global__ void
     const CElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
...
@@ -47,7 +47,7 @@ __global__ void
     const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainK0BlockLoop>(
...
@@ -50,7 +50,7 @@ __global__ void
     const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(
...
@@ -54,7 +54,7 @@ __global__ void
     const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx94__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainKBlockLoop>(
...
@@ -36,8 +36,7 @@ __global__ void
     const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
-    defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
     GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                        p_in_global,
                                        out_grid_desc,
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/math.hpp"
namespace ck {
namespace lds_utils {
/** \brief Allocate a given number of buffers in LDS and return them as a tuple.
 *
 * \tparam DataType Data type of elements to be stored in LDS.
 * \tparam NumBuffers Number of buffers to be allocated.
 * \param lds_ptr Address of the beginning of LDS space.
 * \param num_elems_per_buffer Number of elements to allocate per single buffer.
 * \param start_offset_elems Number of elements to move from the start of LDS for the allocation
 *        of the first buffer.
 * \param lds_alignment Alignment of every buffer allocation given as a number of elements.
 * \return Tuple of dynamic buffers representing memory allocated in LDS.
 */
template <typename DataType, index_t NumBuffers>
__device__ static auto AllocateLdsBuffers(void* lds_ptr,
                                          int32_t num_elems_per_buffer,
                                          int32_t start_offset_elems,
                                          int32_t lds_alignment)
{
    const DataType* lds_start = static_cast<DataType*>(lds_ptr) + start_offset_elems;
    const int32_t single_buffer_offset =
        math::integer_least_multiple(num_elems_per_buffer, lds_alignment);
    return generate_tuple(
        [&](auto i) {
            const int32_t local_offset = i * single_buffer_offset;
            return make_dynamic_buffer<AddressSpaceEnum::Lds>(lds_start + local_offset,
                                                              num_elems_per_buffer);
        },
        Number<NumBuffers>{});
}
} // namespace lds_utils
} // namespace ck
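A minimal usage sketch for the new helper (hypothetical kernel, sizes and include path are illustrative assumptions, not part of this commit). It carves two float buffers out of a single __shared__ allocation; with an alignment of one element the buffers are laid out back to back with no padding.

// Usage sketch (illustrative only).
#include "ck/utility/lds_utils.hpp" // assumed path of the new header

__global__ void example_lds_alloc_kernel()
{
    constexpr ck::index_t elems_per_buffer = 256;
    constexpr ck::index_t num_buffers      = 2;
    // Enough LDS for two unpadded buffers of elems_per_buffer floats each.
    __shared__ char p_shared[num_buffers * elems_per_buffer * sizeof(float)];
    auto lds_buffers = ck::lds_utils::AllocateLdsBuffers<float, num_buffers>(
        p_shared, elems_per_buffer, /*start_offset_elems=*/0, /*lds_alignment=*/1);
    (void)lds_buffers; // tuple of two dynamic LDS buffers over disjoint chunks
}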
@@ -9,6 +9,9 @@
 // TODO: Add arch limitation
 namespace ck {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#define __gfx11__
+#endif
 /********************************WAVE32 MODE***********************************************/
 // src: fp16, dst: fp32
@@ -25,7 +28,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
     // delete them.
     // amd_assembly_wmma_f32_16x16x16_f16_w32(
     // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
     reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
         reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
 #else
@@ -46,7 +49,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
     template <class FloatC>
     __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<float8_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
                 reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
@@ -71,7 +74,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
     // opsel usage
     // false: D0.[0:15] = result
     // true : D0.[16:31]= result
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
     reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
         reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
 #else
@@ -95,7 +98,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
     // opsel usage
     // false: D0.[0:15] = result
     // true : D0.[16:31]= result
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
     reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
         __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
             reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
@@ -117,7 +120,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
     template <class FloatC>
     __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<int32x8_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
                 neg_a,
@@ -145,7 +148,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
     template <class FloatC>
     __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
             reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
 #else
@@ -166,7 +169,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
     template <class FloatC>
     __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<float4_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
                 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
@@ -191,7 +194,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
     // opsel usage
     // false: D0.[0:15] = result
     // true : D0.[16:31]= result
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
     reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
         reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
 #else
@@ -215,7 +218,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
     // opsel usage
     // false: D0.[0:15] = result
     // true : D0.[16:31]= result
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
     reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
         __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
             reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
@@ -237,7 +240,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
     template <class FloatC>
     __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#if defined(__gfx11__)
         reg_c.template AsType<int32x4_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
                 neg_a,
...
@@ -4,6 +4,10 @@
 #pragma once
 namespace ck {
+// Define the common macro for MI300 models
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#define __gfx94__
+#endif
 // fp32
 template <index_t MPerWave, index_t NPerWave>
@@ -341,7 +345,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     template <class FloatC>
     __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx90a__) || defined(__gfx94__)
         reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
             reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
 #else
@@ -361,7 +365,7 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                 bit_cast<long>(reg_a),
@@ -393,7 +397,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
             bit_cast<long>(reg_a),
             bit_cast<long>(reg_b),
@@ -424,7 +428,7 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
     template <class FloatC>
     __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
                 bit_cast<long>(reg_a),
@@ -456,7 +460,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
     template <class FloatC>
     __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
             bit_cast<long>(reg_a),
             bit_cast<long>(reg_b),
@@ -487,7 +491,7 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
                 bit_cast<long>(reg_a),
@@ -519,7 +523,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
     template <class FloatC>
     __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
             bit_cast<long>(reg_a),
             bit_cast<long>(reg_b),
@@ -550,7 +554,7 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
     template <class FloatC>
     __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
                 bit_cast<long>(reg_a),
@@ -582,7 +586,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
     template <class FloatC>
     __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
     {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
             bit_cast<long>(reg_a),
             bit_cast<long>(reg_b),
...
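The __gfx94__ umbrella macro above is only ever defined on top of the per-target compiler macros it abbreviates, so a translation unit can assert that relationship; a minimal compile-time sketch (illustrative, not part of this commit):

// Sanity sketch (assumption: __gfx94__ is meant purely as shorthand for the three MI300 targets).
#if defined(__gfx94__) && \
    !(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#error "__gfx94__ is expected to be defined only for gfx940/gfx941/gfx942 compilations"
#endif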
@@ -8,6 +8,10 @@
 #include "ck/utility/random_gen.hpp"
 namespace ck {
+// Define the common macro for MI300 models
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#define __gfx94__
+#endif
 // Convert X to Y, both X and Y are non-const data types.
 template <typename Y,
@@ -105,7 +109,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
 {
     constexpr int seed = 42;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float max_fp8 = 240.0f;
     x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
     union
@@ -133,7 +137,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
 template <>
 inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // convert to float and use native converion
     return f8_convert_sr<f8_t>(type_convert<float>(x));
 #else
@@ -154,7 +158,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
 {
     constexpr int seed = 42;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     union
     {
         float fval;
@@ -180,7 +184,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
 template <>
 inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // convert to float and use native converion
     return f8_convert_sr<bf8_t>(type_convert<float>(x));
 #else
@@ -203,7 +207,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
 template <>
 inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float max_fp8 = 240.0f;
     x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
     union
@@ -232,7 +236,7 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
 template <>
 inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // convert to float and use native converion
     return f8_convert_rne<f8_t>(type_convert<float>(x));
 #else
@@ -250,7 +254,7 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
 template <>
 inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     union
     {
         float fval;
@@ -277,7 +281,7 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
 template <>
 inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // convert to float and use native converion
     return f8_convert_rne<bf8_t>(type_convert<float>(x));
 #else
@@ -306,7 +310,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
 template <>
 inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float fval;
     uint32_t i32val = static_cast<uint32_t>(x);
     fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
@@ -321,7 +325,7 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
 template <>
 inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     const auto i16val = bit_cast<uint16_t>(x);
     return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
 #else
@@ -363,7 +367,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
 template <>
 inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // use native conversion to float and convert to fp16
     return type_convert<half_t>(type_convert<float>(x));
 #else
@@ -387,7 +391,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
 template <>
 inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     float fval;
     uint32_t i32val = static_cast<uint32_t>(x);
     fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
@@ -414,7 +418,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
 template <>
 inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // use native conversion to float and convert to fp16
     return type_convert<half_t>(type_convert<float>(x));
 #else
...
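The fp8 helpers touched above are declared __host__ __device__, so they can also be exercised from the host, where the emulated (non-builtin) branch is taken. A minimal round-trip sketch with illustrative values; the include path, and the placement of f8_t and the helpers in namespace ck, are assumptions drawn from the hunks above:

// Host-side round-trip sketch (illustrative): fp32 -> fp8 with both rounding modes, then back.
#include "ck/utility/type_convert.hpp" // assumed include path
#include <cstdio>

int main()
{
    float x = 1.5f;
    ck::f8_t y_rne = ck::f8_convert_rne<ck::f8_t>(x); // round-to-nearest-even
    ck::f8_t y_sr  = ck::f8_convert_sr<ck::f8_t>(x);  // stochastic rounding
    float back     = ck::type_convert<float>(y_rne);  // fp8 -> fp32
    std::printf("x = %f, round-trip(rne) = %f\n", x, back);
    (void)y_sr;
    return 0;
}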