Commit d21ab0f5 authored by fengzch

fix: use rocm

parent 181f4e43
@@ -4,5 +4,6 @@ source /usr/local/bin/fastpt -T
 export CPLUS_INCLUDE_PATH=/opt/dtk/roctracer/include:$CPLUS_INCLUDE_PATH
 export AMDGPU_TARGETS="gfx906;gfx926;gfx928;gfx936"
+export FASTPT_USE_ASM=1
 CXX=hipcc CC=hipcc python setup.py bdist_wheel
@@ -308,7 +308,7 @@ def check_hardware_compatibility(quantization_config: dict, device: str | torch.
     if sm == "120": # you can only use the fp4 models
         if quantization_config["weight"]["dtype"] != "fp4_e2m1_all":
             raise ValueError('Please use "fp4" quantization for Blackwell GPUs. ')
-    elif sm in ["75", "80", "86", "89"]:
+    elif sm in ["75", "80", "86", "89", "92", "93"]:
         if quantization_config["weight"]["dtype"] != "int4":
             raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs. ')
     else:
...
@@ -12,6 +12,7 @@ https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutl
 #pragma once
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>
 #include <cstdint>
 __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
@@ -67,12 +68,17 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
     // Finally, we construct the output numbers.
     // Convert elt_01
     // asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+    h[0] = __hsub(h[0], __float2half(1024.0f));
     // Convert elt_23
     // asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    h[1] = __hfma(h[1], __float2half(0.0625f), __float2half(-64.0f));
     // Convert elt_45
     // asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+    h[2] = __hsub(h[2], __float2half(1024.0f));
     // Convert elt_67
     // asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    h[3] = __hfma(h[3], __float2half(0.0625f), __float2half(-64.0f));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
 __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
@@ -121,11 +127,16 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &so
     // Finally, we construct the output numbers.
     // Convert elt_01
-    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(BF16_ONE), "r"(BF16_BIAS));
-    // Convert elt_23
-    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(BF16_ONE), "r"(BF16_BIAS));
-    // Convert elt_45
-    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
-    // Convert elt_67
-    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    // // Convert elt_23
+    // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    // // Convert elt_45
+    // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    // // Convert elt_67
+    // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    h[0] = __hsub(h[0], __float2bfloat16_rn(128.0f));
+    h[1] = __hsub(h[1], __float2bfloat16_rn(128.0f));
+    h[2] = __hsub(h[2], __float2bfloat16_rn(128.0f));
+    h[3] = __hsub(h[3], __float2bfloat16_rn(128.0f));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
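For reference, the intrinsic lines added above replicate the magic-number dequantization from the upstream FasterTransformer converter. Below is a minimal sketch of the lane-wise arithmetic behind the disabled f16x2/bf16x2 PTX, under the assumption that each `h[i]` aliases a packed two-element vector (`half2` / `__nv_bfloat162`) as in the original code; `dequant_bias_fix_sketch` is a hypothetical helper, not part of the commit:

```cpp
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Sketch only (assumed packed-register layout): the even output pairs carry a
// +1024 exponent bias, the odd pairs are scaled by 1/16 and offset by -64; the
// bf16 trick leaves a +128 (2^7) bias in every lane.
__device__ void dequant_bias_fix_sketch(half2 *h_fp16, __nv_bfloat162 *h_bf16) {
    const half2 bias_1024 = __float2half2_rn(1024.0f);  // FP16_TOP_MAGIC_NUM
    const half2 one_16th  = __float2half2_rn(0.0625f);  // ONE_SIXTEENTH
    const half2 neg_64    = __float2half2_rn(-64.0f);   // NEG_64
    h_fp16[0] = __hsub2(h_fp16[0], bias_1024);          // replaces sub.f16x2 (elt_01)
    h_fp16[1] = __hfma2(h_fp16[1], one_16th, neg_64);   // replaces fma.rn.f16x2 (elt_23)
    h_fp16[2] = __hsub2(h_fp16[2], bias_1024);          // replaces sub.f16x2 (elt_45)
    h_fp16[3] = __hfma2(h_fp16[3], one_16th, neg_64);   // replaces fma.rn.f16x2 (elt_67)

    const __nv_bfloat162 bias_128 = __float2bfloat162_rn(128.0f);
    for (int i = 0; i < 4; ++i) {
        h_bf16[i] = __hsub2(h_bf16[i], bias_128);       // removes the bf16 exponent bias
    }
}
```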
@@ -81,10 +81,10 @@ __device__ void sync_slice(int slice_id) {
 __inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
     uint32_t smem_int_ptr;
-    asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
-        : "=r"(smem_int_ptr)
-        : "l"(ptr));
+    // asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
+    //     : "=r"(smem_int_ptr)
+    //     : "l"(ptr));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return smem_int_ptr;
 }
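With the `cvta.to.shared` sequence commented out, `smem_int_ptr` is returned uninitialized. A minimal sketch of a defined fallback, assuming callers only need a 32-bit token derived from the shared-memory pointer rather than a true PTX shared-window address (`cast_smem_ptr_to_uint_sketch` is hypothetical):

```cpp
#include <cstdint>

// Sketch: truncate the generic address instead of cvta.to.shared + cvt.u32.u64.
__device__ uint32_t cast_smem_ptr_to_uint_sketch(void const *const ptr) {
    return static_cast<uint32_t>(reinterpret_cast<uintptr_t>(ptr));
}
```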
@@ -92,38 +92,41 @@ template<typename f16_t>
 __inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
     static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                   "ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
-    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
-                 "{%0, %1, %2, %3}, [%4];"
-                 : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
-                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
-                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
-                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
-                 : "r"(addr));
+    // asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+    //              "{%0, %1, %2, %3}, [%4];"
+    //              : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
+    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
+    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
+    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
+    //              : "r"(addr));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
 template<typename f16_t>
 __inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
     static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                   "ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
-    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
-                 "{%0, %1, %2, %3}, [%4];"
-                 : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
-                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
-                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
-                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
-                 : "r"(addr));
+    // asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+    //              "{%0, %1, %2, %3}, [%4];"
+    //              : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
+    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
+    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
+    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
+    //              : "r"(addr));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
 __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) {
     const int cp_size = 16;
-    asm volatile("{"
-                 " .reg .pred p;"
-                 " setp.ne.b32 p, %0, 0;"
-                 " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
-                 "}" ::"r"((int)mask),
-                 "r"(smem_int_ptr),
-                 "l"(src),
-                 "n"(cp_size));
+    // asm volatile("{"
+    //              " .reg .pred p;"
+    //              " setp.ne.b32 p, %0, 0;"
+    //              " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
+    //              "}" ::"r"((int)mask),
+    //              "r"(smem_int_ptr),
+    //              "l"(src),
+    //              "n"(cp_size));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
 template<typename f16_t>
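`cp.async` has no HIP counterpart in this code path, so the predicated 16-byte copy can only be expressed synchronously. A sketch with a hypothetical signature that takes the destination as a pointer rather than the 32-bit shared address the PTX version used:

```cpp
// Sketch: synchronous stand-in for the disabled cp.async.cg path; smem_dst is
// a hypothetical parameter replacing the smem_int_ptr address token.
__device__ void cp_async_cg_A_fallback(uint4 *smem_dst, const uint4 *__restrict__ src, bool mask) {
    if (mask) {
        *smem_dst = *src;  // one 16-byte copy per thread, no async pipeline
    }
}
```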
@@ -131,39 +134,41 @@ __device__ __inline__ void mma_m16n8k16(float *C_warp, f16_t *A_shared_warp, f16
 template<>
 __device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp) {
-    asm volatile(
-        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
-        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
-        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
-        : "r"(((unsigned *)A_shared_warp)[0]),
-          "r"(((unsigned *)A_shared_warp)[1]),
-          "r"(((unsigned *)A_shared_warp)[2]),
-          "r"(((unsigned *)A_shared_warp)[3]),
-          "r"(((unsigned *)B_shared_warp)[0]),
-          "r"(((unsigned *)B_shared_warp)[1]),
-          "f"(((float *)C_warp)[0]),
-          "f"(((float *)C_warp)[1]),
-          "f"(((float *)C_warp)[2]),
-          "f"(((float *)C_warp)[3]));
+    // asm volatile(
+    //     "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+    //     "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
+    //     : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
+    //     : "r"(((unsigned *)A_shared_warp)[0]),
+    //       "r"(((unsigned *)A_shared_warp)[1]),
+    //       "r"(((unsigned *)A_shared_warp)[2]),
+    //       "r"(((unsigned *)A_shared_warp)[3]),
+    //       "r"(((unsigned *)B_shared_warp)[0]),
+    //       "r"(((unsigned *)B_shared_warp)[1]),
+    //       "f"(((float *)C_warp)[0]),
+    //       "f"(((float *)C_warp)[1]),
+    //       "f"(((float *)C_warp)[2]),
+    //       "f"(((float *)C_warp)[3]));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
 template<>
 __device__ __inline__ void
 mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
-    asm volatile(
-        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
-        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
-        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
-        : "r"(((unsigned *)A_shared_warp)[0]),
-          "r"(((unsigned *)A_shared_warp)[1]),
-          "r"(((unsigned *)A_shared_warp)[2]),
-          "r"(((unsigned *)A_shared_warp)[3]),
-          "r"(((unsigned *)B_shared_warp)[0]),
-          "r"(((unsigned *)B_shared_warp)[1]),
-          "f"(((float *)C_warp)[0]),
-          "f"(((float *)C_warp)[1]),
-          "f"(((float *)C_warp)[2]),
-          "f"(((float *)C_warp)[3]));
+    // asm volatile(
+    //     "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
+    //     "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
+    //     : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
+    //     : "r"(((unsigned *)A_shared_warp)[0]),
+    //       "r"(((unsigned *)A_shared_warp)[1]),
+    //       "r"(((unsigned *)A_shared_warp)[2]),
+    //       "r"(((unsigned *)A_shared_warp)[3]),
+    //       "r"(((unsigned *)B_shared_warp)[0]),
+    //       "r"(((unsigned *)B_shared_warp)[1]),
+    //       "f"(((float *)C_warp)[0]),
+    //       "f"(((float *)C_warp)[1]),
+    //       "f"(((float *)C_warp)[2]),
+    //       "f"(((float *)C_warp)[3]));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
 template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
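The disabled `mma.sync.m16n8k16` paths have no PTX-free equivalent; on AMD matrix-core hardware the closest analogue is an MFMA builtin. The sketch below is for orientation only: the per-thread fragment layout of PTX `mma.sync` differs from AMD's 16x16x16 MFMA, and the builtin exists only on gfx targets with matrix cores (gfx906, for example, has none), so this is not a drop-in replacement:

```cpp
// Reference sketch of a CDNA-style matrix-core FMA (assumes the target exposes
// __builtin_amdgcn_mfma_f32_16x16x16f16; helper name is hypothetical).
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
typedef float float4_t __attribute__((ext_vector_type(4)));

__device__ float4_t mfma_f32_16x16x16_f16_sketch(half4_t a, half4_t b, float4_t acc) {
#if defined(__gfx908__) || defined(__gfx90a__)
    return __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, acc, 0, 0, 0);  // cbsz/abid/blgp = 0
#else
    return acc;  // placeholder on targets without this MFMA instruction
#endif
}
```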
@@ -944,10 +949,11 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
                               int M,
                               int N,
                               int K) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-    trap_unsupported_arch();
-    return;
-#endif
+    //#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    // trap_unsupported_arch();
+    // return;
+    //#endif
+    // printf("LOG(INFO) %s: %d %s\n", __FILE__, __LINE__, __func__);
     using f162_t = typename packed_as<f16_t, 2>::type;
     constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
     constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
...
@@ -171,7 +171,7 @@ inline __device__ T ldg(const T *val) {
 #define float22bf162 __float22bfloat162_rn
 #define bf162bf162 __bfloat162bfloat162
 inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
     float2 f_val;
     f_val.x = max(min(__low2float(val), 127.f), -128.f);
     f_val.y = max(min(__high2float(val), 127.f), -128.f);
@@ -203,7 +203,7 @@ inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
 #if ENABLE_BF16
 template<>
 inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
     return val[0];
 #else
     return __ldg(val);
@@ -212,7 +212,7 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val) {
 template<>
 inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16 *val) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
     return val[0];
 #else
     return __ldg(val);
...
@@ -191,14 +191,26 @@ public:
     // set nan values to -inf
     __device__ __forceinline__ static half2_t fix_nan(half2_t input) {
-        static constexpr float neginf = -std::numeric_limits<float>::infinity();
+        // static constexpr float neginf = -std::numeric_limits<float>::infinity();
         /**
          * In accordance to the IEEE-754R standard,
          * if one of the input parameters to fminf(), fmin(), fmaxf(), or fmax() is NaN,
          * but not the other,
          * the result is the non-NaN parameter.
          */
-        return __hmax2(input, half2_t(neginf, neginf));
+        // return __hmax2(input, half2_t(neginf, neginf));
+        half_t lo = __low2half(input);
+        half_t hi = __high2half(input);
+        // Check each half with __hisnan and replace NaNs with -inf.
+        half_t neg_inf = __float2half(-std::numeric_limits<float>::infinity());
+        half_t out_lo = __hisnan(lo) ? neg_inf : lo;
+        half_t out_hi = __hisnan(hi) ? neg_inf : hi;
+        // Pack the two halves back into a half2_t.
+        return __halves2half2(out_lo, out_hi);
     }
     __device__ __forceinline__ static float fix_nan(float input) {
@@ -511,7 +523,8 @@ public:
             if (alwaysfalse) {
                 dummy = clock();
             }
-            // asm volatile ("membar.cta;");
+            asm volatile ("membar.cta;");
         }
     }
...
@@ -50,29 +50,33 @@ template<typename T>
 __device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
     if constexpr (sizeof(T) == 4) {
         uint32_t data;
-        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
-                     "@loadpred ld.global.nc.b32 %0, [%1];"
-                     "}"
-                     : "=r"(data)
-                     : "l"(addr), "r"((int)pred));
+        // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
+        //              "@loadpred ld.global.nc.b32 %0, [%1];"
+        //              "}"
+        //              : "=r"(data)
+        //              : "l"(addr), "r"((int)pred));
+        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
         return *reinterpret_cast<T *>(&data);
     }
     if constexpr (sizeof(T) == 8) {
         uint2 data;
-        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
-                     "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
-                     "}"
-                     : "=r"(data.x), "=r"(data.y)
-                     : "l"(addr), "r"((int)pred));
+        // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
+        //              "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
+        //              "}"
+        //              : "=r"(data.x), "=r"(data.y)
+        //              : "l"(addr), "r"((int)pred));
+        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
         return *reinterpret_cast<T *>(&data);
     }
     if constexpr (sizeof(T) == 16) {
         uint4 data;
-        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
-                     "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
-                     "}"
-                     : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
-                     : "l"(addr), "r"((int)pred));
+        // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
+        //              "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
+        //              "}"
+        //              : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
+        //              : "l"(addr), "r"((int)pred));
+        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
         return *reinterpret_cast<T *>(&data);
     }
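For reference, the predicated non-coherent loads disabled above can be written without PTX as guarded plain loads; a minimal sketch (the `ld.global.nc` cache hint is simply dropped, and the destination is zero-initialized where the PTX version left it undefined):

```cpp
// Sketch: portable predicated load with the same contract as load_pred when
// pred is true; load_pred_sketch is a hypothetical helper, not the commit's code.
template<typename T>
__device__ T load_pred_sketch(const T *addr, bool pred) {
    T data{};
    if (pred) {
        data = *addr;
    }
    return data;
}
```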
@@ -88,17 +92,21 @@ __device__ __forceinline__ static void store(T *addr, T val) {
     if constexpr (shmem) {
         if constexpr (sizeof(T) == 8) {
             uint2 data = *reinterpret_cast<uint2 *>(&val);
-            asm volatile(
-                "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
+            // asm volatile(
+            //     "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
+            // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
             return;
         }
         if constexpr (sizeof(T) == 16) {
             uint4 data = *reinterpret_cast<uint4 *>(&val);
-            asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
-                         "r"(data.x),
-                         "r"(data.y),
-                         "r"(data.z),
-                         "r"(data.w));
+            // asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
+            //              "r"(data.x),
+            //              "r"(data.y),
+            //              "r"(data.z),
+            //              "r"(data.w));
+            // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
             return;
         }
         *addr = val;
@@ -127,33 +135,39 @@ template<typename T>
 __device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
     if constexpr (sizeof(T) == 4) {
         uint32_t data = *reinterpret_cast<uint32_t *>(&val);
-        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
-                     "@storepred st.global.cg.b32 [%1], %2;"
-                     "}" ::"r"((int)pred),
-                     "l"(addr),
-                     "r"(data));
+        // asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+        //              "@storepred st.global.cg.b32 [%1], %2;"
+        //              "}" ::"r"((int)pred),
+        //              "l"(addr),
+        //              "r"(data));
+        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
         return;
     }
     if constexpr (sizeof(T) == 8) {
         uint2 data = *reinterpret_cast<uint2 *>(&val);
-        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
-                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
-                     "}" ::"r"((int)pred),
-                     "l"(addr),
-                     "r"(data.x),
-                     "r"(data.y));
+        // asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+        //              "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
+        //              "}" ::"r"((int)pred),
+        //              "l"(addr),
+        //              "r"(data.x),
+        //              "r"(data.y));
+        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
         return;
     }
     if constexpr (sizeof(T) == 16) {
         uint4 data = *reinterpret_cast<uint4 *>(&val);
-        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
-                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
-                     "}" ::"r"((int)pred),
-                     "l"(addr),
-                     "r"(data.x),
-                     "r"(data.y),
-                     "r"(data.z),
-                     "r"(data.w));
+        // asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+        //              "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
+        //              "}" ::"r"((int)pred),
+        //              "l"(addr),
+        //              "r"(data.x),
+        //              "r"(data.y),
+        //              "r"(data.z),
+        //              "r"(data.w));
+        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
         return;
     }
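The matching predicated stores admit the same PTX-free form, dropping the `.cg` cache hint; a sketch with a hypothetical helper name:

```cpp
// Sketch: guarded plain store in place of the disabled st.global.cg paths.
template<typename T>
__device__ void store_pred_sketch(T *addr, T val, bool pred) {
    if (pred) {
        *addr = val;
    }
}
```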
@@ -194,14 +208,16 @@ __device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
 __device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
     asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                  : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
-                 : "l"((ptr))); // limengmeng
+                 : "l"((ptr)));
 }
 template<typename T>
 __device__ __forceinline__ static T movmatrix(T x) {
-    asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-                 : "=r"(*reinterpret_cast<uint32_t *>(&x))
-                 : "r"(*reinterpret_cast<uint32_t *>(&x)));
+    // asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
+    //              : "=r"(*reinterpret_cast<uint32_t *>(&x))
+    //              : "r"(*reinterpret_cast<uint32_t *>(&x)));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return x;
 }
@@ -215,7 +231,9 @@ __device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
     uint32_t result;
     asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
     asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
-    asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
+    // asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return result;
 }
@@ -225,7 +243,9 @@ __device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
     uint32_t result;
     asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
     asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
-    asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
+    // asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return result;
 }
@@ -235,22 +255,27 @@ __device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
     uint32_t result;
     asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
     asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
-    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
+    // asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return result;
 }
 __device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
     uint32_t result;
-    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
-                 : "=r"(result)
-                 : "f"(value.y), "f"(value.x));
+    // asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
+    //              : "=r"(result)
+    //              : "f"(value.y), "f"(value.x));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return result;
 }
 __device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
     uint16_t lo, hi;
-    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
-    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
+    // asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
+    // asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return uint32_t(lo) | (uint32_t(hi) << 16);
 }
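With the `cvt.pack.sat` conversions commented out, `result` is returned uninitialized. For reference, a portable sketch of the signed 4-bit case, assuming the usual `cvt.pack.sat.s4.s32.b32 d, a, b, 0` semantics (saturate both operands, pack `a` into bits 7..4 and `b` into bits 3..0). The u4 and s8 variants follow the same pattern with ranges [0, 15] and [-128, 127]; the fp4/fp8 packings would additionally need an e2m1/e4m3 encoder and are not sketched here:

```cpp
#include <cstdint>

// Sketch: clamp-and-pack equivalent of cvt.pack.sat.s4.s32.b32 with c = 0.
// v1 corresponds to value.x (low nibble), v2 to value.y (high nibble).
__device__ uint32_t pack_s4x2_sketch(int v1, int v2) {
    v1 = min(max(v1, -8), 7);  // saturate to the s4 range
    v2 = min(max(v2, -8), 7);
    return (static_cast<uint32_t>(v2 & 0xF) << 4) | static_cast<uint32_t>(v1 & 0xF);
}
```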
@@ -347,15 +372,17 @@ __device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bf
 };
 __device__ __forceinline__ static void reduce_add(float *addr, float val) {
-    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
+    // asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
 __device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
-    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
-                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
-                 "}" ::"r"((int)pred),
-                 "l"(addr),
-                 "f"(val));
+    // asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+    //              "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
+    //              "}" ::"r"((int)pred),
+    //              "l"(addr),
+    //              "f"(val));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
 template<int cnt, typename F>
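The relaxed global float reductions above have a direct fallback in `atomicAdd`, which HIP supports for `float`; a sketch with hypothetical helper names:

```cpp
// Sketch: atomicAdd in place of red.relaxed.gpu.global.add.f32.
__device__ void reduce_add_sketch(float *addr, float val) {
    atomicAdd(addr, val);
}
__device__ void reduce_add_pred_sketch(float *addr, float val, bool pred) {
    if (pred) {
        atomicAdd(addr, val);
    }
}
```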
@@ -369,9 +396,10 @@ __device__ __forceinline__ static void unrolled_loop(F &&lambda) {
 __device__ __forceinline__ static float int2float_fast(int val) {
     float fval;
     // fval = (val & 0x7FFFFF) ^ 0x4B400000
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
-                 : "=f"(fval)
-                 : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
+    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
+    //              : "=f"(fval)
+    //              : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return fval - 12582912.0f;
 }
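`int2float_fast` now returns an uninitialized value once the `lop3` is removed. The same magic-number conversion can be written portably; a sketch, assuming `val` stays within the 23-bit range the original trick already required:

```cpp
#include <cstdint>

// Sketch: (val & 0x7FFFFF) ^ 0x4B400000 forms a float equal to 1.5*2^23 + val;
// subtracting 12582912.0f (= 1.5 * 2^23) recovers val as a float.
__device__ float int2float_fast_sketch(int val) {
    uint32_t bits = (static_cast<uint32_t>(val) & 0x7FFFFFu) ^ 0x4B400000u;
    return __uint_as_float(bits) - 12582912.0f;
}
```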
@@ -388,12 +416,13 @@ __device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
     uint32_t ival;
     uint32_t hval;
     // ival.lo = x.lo; ival.hi = y.lo;
-    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
+    // asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
     ival = ival >> 4;
     // (val & 0x03FF03FF) ^ 0x76007600
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
-                 : "=r"(hval)
-                 : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
+    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
+    //              : "=r"(hval)
+    //              : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
 }
 // val in [-4096, 4095], steps of 8, round to nearest
@@ -407,11 +436,12 @@ __device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
     uint32_t hval;
     // ival.lo = x.hi; ival.hi = y.hi;
     // <=> divide x and y by 65536 and pack them
-    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
+    // asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
     // (val & 0x03FF03FF) ^ 0x72007200
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
-                 : "=r"(hval)
-                 : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
+    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
+    //              : "=r"(hval)
+    //              : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
 }
 // val in [-512, 511]
@@ -420,11 +450,12 @@ __device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
     uint32_t hval;
     // ival.lo = x.lo; ival.hi = y.lo;
     // <=> divide x and y by 65536 and pack them
-    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
+    //asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
     // (val & 0x03FF03FF) ^ 0x66006600
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
-                 : "=r"(hval)
-                 : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
+    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
+    //              : "=r"(hval)
+    //              : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
 }
...
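All three `int2half2_fast_*` helpers rely on a `prmt.b32` byte permute followed by the same `lop3` mask-and-xor trick. The two byte-select patterns used here are straightforward to express without PTX; a sketch with hypothetical helper names (selector 0x5410 packs the low halfwords of `x` and `y`, 0x7632 packs the high halfwords):

```cpp
#include <cstdint>

// Sketch: plain-C++ equivalents of the prmt.b32 selectors used above.
__device__ uint32_t prmt_5410_sketch(uint32_t x, uint32_t y) {
    return (x & 0xFFFFu) | (y << 16);        // ival.lo = x.lo; ival.hi = y.lo
}
__device__ uint32_t prmt_7632_sketch(uint32_t x, uint32_t y) {
    return (x >> 16) | (y & 0xFFFF0000u);    // ival.lo = x.hi; ival.hi = y.hi
}
```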
@@ -247,7 +247,7 @@ public:
         //              "r"(wmscale),
         //              "n"(0),
         //              "h"((short)(idb * 2 + 1)));
-        std::cout << __func__ << "mma_fp4 is not implemented for HIP yet[asm error!!!]" << std::endl;
+        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
         return out;
     }
@@ -334,7 +334,8 @@ public:
                 dummy = clock();
             }
-            // asm volatile ("membar.cta;");
+            asm volatile ("membar.cta;");
         }
     }
@@ -916,7 +917,9 @@ }
         }
         // #endif
-        // asm volatile ("membar.cta;");
+        asm volatile ("membar.cta;");
     }
 }
...
@@ -10,40 +10,42 @@ public:
     __device__ __forceinline__ static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
         // packed_psum_t psum;
-        asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
-                     "{%0, %1, %2, %3},"
-                     "{%4, %5, %6, %7},"
-                     "{%8, %9},"
-                     "{%10, %11, %12, %13};\n"
-                     : "=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
-                     : "r"(act.x),
-                       "r"(act.y),
-                       "r"(act.z),
-                       "r"(act.w),
-                       "r"(wgt.x),
-                       "r"(wgt.y),
-                       // "r"(0), "r"(0), "r"(0), "r"(0)
-                       "r"(psum.data[0]),
-                       "r"(psum.data[1]),
-                       "r"(psum.data[2]),
-                       "r"(psum.data[3]));
-        asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
-                     "{%0, %1, %2, %3},"
-                     "{%4, %5, %6, %7},"
-                     "{%8, %9},"
-                     "{%10, %11, %12, %13};\n"
-                     : "=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
-                     : "r"(act.x),
-                       "r"(act.y),
-                       "r"(act.z),
-                       "r"(act.w),
-                       "r"(wgt.z),
-                       "r"(wgt.w),
-                       // "r"(0), "r"(0), "r"(0), "r"(0)
-                       "r"(psum.data[4]),
-                       "r"(psum.data[5]),
-                       "r"(psum.data[6]),
-                       "r"(psum.data[7]));
+        // asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
+        //              "{%0, %1, %2, %3},"
+        //              "{%4, %5, %6, %7},"
+        //              "{%8, %9},"
+        //              "{%10, %11, %12, %13};\n"
+        //              : "=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
+        //              : "r"(act.x),
+        //                "r"(act.y),
+        //                "r"(act.z),
+        //                "r"(act.w),
+        //                "r"(wgt.x),
+        //                "r"(wgt.y),
+        //                // "r"(0), "r"(0), "r"(0), "r"(0)
+        //                "r"(psum.data[0]),
+        //                "r"(psum.data[1]),
+        //                "r"(psum.data[2]),
+        //                "r"(psum.data[3]));
+        // asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
+        //              "{%0, %1, %2, %3},"
+        //              "{%4, %5, %6, %7},"
+        //              "{%8, %9},"
+        //              "{%10, %11, %12, %13};\n"
+        //              : "=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
+        //              : "r"(act.x),
+        //                "r"(act.y),
+        //                "r"(act.z),
+        //                "r"(act.w),
+        //                "r"(wgt.z),
+        //                "r"(wgt.w),
+        //                // "r"(0), "r"(0), "r"(0), "r"(0)
+        //                "r"(psum.data[4]),
+        //                "r"(psum.data[5]),
+        //                "r"(psum.data[6]),
+        //                "r"(psum.data[7]));
+        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
         return psum;
     }
@@ -418,7 +420,8 @@ public:
         //     dummy = clock();
         // }
-        // asm volatile ("membar.cta;");
+        asm volatile ("membar.cta;");
     }
 }
...
@@ -110,65 +110,65 @@ __device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b,
     uint4 d;
     static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
-        "{%0, %1, %2, %3},"
-        "{%4, %5, %6, %7},"
-        "{%8, %9},"
-        "{%10, %11, %12, %13};\n"
-        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
-        : "r"(a.x),
-        "r"(a.y),
-        "r"(a.z),
-        "r"(a.w),
-        "r"(b.x),
-        "r"(b.y),
-        "r"(c.x),
-        "r"(c.y),
-        "r"(c.z),
-        "r"(c.w),
-        "n"(K),
-        "C"(AType::value),
-        "C"(BType::value));
-#else
-    asm volatile("{"
-        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
-        "{tmp0, tmp1},"
-        "{%4},"
-        "{%8},"
-        "{%10, %11};\n"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
-        "{tmp2, tmp3},"
-        "{%5},"
-        "{%8},"
-        "{%12, %13};\n"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
-        "{%0, %1},"
-        "{%6},"
-        "{%9},"
-        "{tmp0, tmp1};\n"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
-        "{%2, %3},"
-        "{%7},"
-        "{%9},"
-        "{tmp2, tmp3};\n"
-        "}\n"
-        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
-        : "r"(a.x),
-        "r"(a.y),
-        "r"(a.z),
-        "r"(a.w),
-        "r"(b.x),
-        "r"(b.y),
-        "r"(c.x),
-        "r"(c.y),
-        "r"(c.z),
-        "r"(c.w),
-        "n"(K / 2),
-        "C"(AType::value),
-        "C"(BType::value));
-#endif
+    // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    // asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
+    //     "{%0, %1, %2, %3},"
+    //     "{%4, %5, %6, %7},"
+    //     "{%8, %9},"
+    //     "{%10, %11, %12, %13};\n"
+    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
+    //     : "r"(a.x),
+    //     "r"(a.y),
+    //     "r"(a.z),
+    //     "r"(a.w),
+    //     "r"(b.x),
+    //     "r"(b.y),
+    //     "r"(c.x),
+    //     "r"(c.y),
+    //     "r"(c.z),
+    //     "r"(c.w),
+    //     "n"(K),
+    //     "C"(AType::value),
+    //     "C"(BType::value));
+    // #else
+    // asm volatile("{"
+    //     ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
+    //     "{tmp0, tmp1},"
+    //     "{%4},"
+    //     "{%8},"
+    //     "{%10, %11};\n"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
+    //     "{tmp2, tmp3},"
+    //     "{%5},"
+    //     "{%8},"
+    //     "{%12, %13};\n"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
+    //     "{%0, %1},"
+    //     "{%6},"
+    //     "{%9},"
+    //     "{tmp0, tmp1};\n"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
+    //     "{%2, %3},"
+    //     "{%7},"
+    //     "{%9},"
+    //     "{tmp2, tmp3};\n"
+    //     "}\n"
+    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
+    //     : "r"(a.x),
+    //     "r"(a.y),
+    //     "r"(a.z),
+    //     "r"(a.w),
+    //     "r"(b.x),
+    //     "r"(b.y),
+    //     "r"(c.x),
+    //     "r"(c.y),
+    //     "r"(c.z),
+    //     "r"(c.w),
+    //     "n"(K / 2),
+    //     "C"(AType::value),
+    //     "C"(BType::value));
+    // #endif
     return d;
 }
...
@@ -36,31 +36,32 @@ using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
 __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
     uint2 d;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
-        "{%0, %1},"
-        "{%2, %3, %4, %5},"
-        "{%6, %7},"
-        "{%8, %9};\n"
-        : "=r"(d.x), "=r"(d.y)
-        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
-#else
-    asm volatile("{"
-        ".reg .b32 tmp0, tmp1;"
-        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
-        "{tmp0, tmp1},"
-        "{%2, %3},"
-        "{%6},"
-        "{%8, %9};\n"
-        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
-        "{%0, %1},"
-        "{%4, %5},"
-        "{%7},"
-        "{tmp0, tmp1};"
-        "}\n"
-        : "=r"(d.x), "=r"(d.y)
-        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
-#endif
+    // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    // asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
+    //     "{%0, %1},"
+    //     "{%2, %3, %4, %5},"
+    //     "{%6, %7},"
+    //     "{%8, %9};\n"
+    //     : "=r"(d.x), "=r"(d.y)
+    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
+    // #else
+    // asm volatile("{"
+    //     ".reg .b32 tmp0, tmp1;"
+    //     "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
+    //     "{tmp0, tmp1},"
+    //     "{%2, %3},"
+    //     "{%6},"
+    //     "{%8, %9};\n"
+    //     "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
+    //     "{%0, %1},"
+    //     "{%4, %5},"
+    //     "{%7},"
+    //     "{tmp0, tmp1};"
+    //     "}\n"
+    //     : "=r"(d.x), "=r"(d.y)
+    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
+    // #endif
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return d;
 }
@@ -71,13 +72,14 @@ __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2
 template<>
 __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
     uint4 d;
-    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
-        "{%0, %1, %2, %3},"
-        "{%4, %5, %6, %7},"
-        "{%8, %9},"
-        "{%10, %11, %12, %13};\n"
-        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
-        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
+    // asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+    //     "{%0, %1, %2, %3},"
+    //     "{%4, %5, %6, %7},"
+    //     "{%8, %9},"
+    //     "{%10, %11, %12, %13};\n"
+    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
+    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return d;
 }
 #endif
@@ -85,31 +87,32 @@ __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2
 template<>
 __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
     uint4 d;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
-        "{%0, %1, %2, %3},"
-        "{%4, %5, %6, %7},"
-        "{%8, %9},"
-        "{%10, %11, %12, %13};\n"
-        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
-        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
-#else
-    asm volatile("{"
-        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
-        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
-        "{tmp0, tmp1, tmp2, tmp3},"
-        "{%4, %5},"
-        "{%8},"
-        "{%10, %11, %12, %13};\n"
-        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
-        "{%0, %1, %2, %3},"
-        "{%6, %7},"
-        "{%9},"
-        "{tmp0, tmp1, tmp2, tmp3};"
-        "}\n"
-        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
-        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
-#endif
+    // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    // asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+    //     "{%0, %1, %2, %3},"
+    //     "{%4, %5, %6, %7},"
+    //     "{%8, %9},"
+    //     "{%10, %11, %12, %13};\n"
+    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
+    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
+    // #else
+    // asm volatile("{"
+    //     ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
+    //     "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+    //     "{tmp0, tmp1, tmp2, tmp3},"
+    //     "{%4, %5},"
+    //     "{%8},"
+    //     "{%10, %11, %12, %13};\n"
+    //     "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+    //     "{%0, %1, %2, %3},"
+    //     "{%6, %7},"
+    //     "{%9},"
+    //     "{tmp0, tmp1, tmp2, tmp3};"
+    //     "}\n"
+    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
+    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
+    // #endif
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return d;
 }
@@ -121,7 +124,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
     uint4 d;
     static constexpr int K = 64;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     // asm volatile(
    //     "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
    //     "{%0, %1, %2, %3},"
@@ -130,43 +133,44 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
    //     "{%10, %11, %12, %13};\n"
    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
-#else
-    asm volatile("{"
-        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
-        "{tmp0, tmp1},"
-        "{%4},"
-        "{%8},"
-        "{%10, %11};\n"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
-        "{tmp2, tmp3},"
-        "{%5},"
-        "{%8},"
-        "{%12, %13};\n"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
-        "{%0, %1},"
-        "{%6},"
-        "{%9},"
-        "{tmp0, tmp1};\n"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
-        "{%2, %3},"
-        "{%7},"
-        "{%9},"
-        "{tmp2, tmp3};\n"
-        "}\n"
-        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
-        : "r"(a.x),
-        "r"(a.y),
-        "r"(a.z),
-        "r"(a.w),
-        "r"(b.x),
-        "r"(b.y),
-        "r"(c.x),
-        "r"(c.y),
-        "r"(c.z),
-        "r"(c.w),
-        "n"(K / 2));
-#endif
+    // #else
+    // asm volatile("{"
+    //     ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
+    //     "{tmp0, tmp1},"
+    //     "{%4},"
+    //     "{%8},"
+    //     "{%10, %11};\n"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
+    //     "{tmp2, tmp3},"
+    //     "{%5},"
+    //     "{%8},"
+    //     "{%12, %13};\n"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
+    //     "{%0, %1},"
+    //     "{%6},"
+    //     "{%9},"
+    //     "{tmp0, tmp1};\n"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
+    //     "{%2, %3},"
+    //     "{%7},"
+    //     "{%9},"
+    //     "{tmp2, tmp3};\n"
+    //     "}\n"
+    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
+    //     : "r"(a.x),
+    //     "r"(a.y),
+    //     "r"(a.z),
+    //     "r"(a.w),
+    //     "r"(b.x),
+    //     "r"(b.y),
+    //     "r"(c.x),
+    //     "r"(c.y),
+    //     "r"(c.z),
+    //     "r"(c.w),
+    //     "n"(K / 2));
+    // #endif
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
     return d;
 }
@@ -175,7 +179,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
     uint4 d;
     static constexpr int K = 64;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     // asm volatile(
    //     "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
    //     "{%0, %1, %2, %3},"
@@ -184,43 +188,44 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
    //     "{%10, %11, %12, %13};\n"
    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
-#else
-    asm volatile("{"
-        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
-        "{tmp0, tmp1},"
-        "{%4},"
-        "{%8},"
-        "{%10, %11};\n"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
-        "{tmp2, tmp3},"
-        "{%5},"
-        "{%8},"
-        "{%12, %13};\n"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
-        "{%0, %1},"
-        "{%6},"
-        "{%9},"
-        "{tmp0, tmp1};\n"
-        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
-        "{%2, %3},"
-        "{%7},"
-        "{%9},"
-        "{tmp2, tmp3};\n"
-        "}\n"
-        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
-        : "r"(a.x),
-        "r"(a.y),
-        "r"(a.z),
-        "r"(a.w),
-        "r"(b.x),
-        "r"(b.y),
-        "r"(c.x),
-        "r"(c.y),
-        "r"(c.z),
-        "r"(c.w),
-        "n"(K / 2));
-#endif
+    // #else
+    // asm volatile("{"
+    //     ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
+    //     "{tmp0, tmp1},"
+    //     "{%4},"
+    //     "{%8},"
+    //     "{%10, %11};\n"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
+    //     "{tmp2, tmp3},"
+    //     "{%5},"
+    //     "{%8},"
+    //     "{%12, %13};\n"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
+    //     "{%0, %1},"
+    //     "{%6},"
+    //     "{%9},"
+    //     "{tmp0, tmp1};\n"
+    //     "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
+    //     "{%2, %3},"
+    //     "{%7},"
+    //     "{%9},"
+    //     "{tmp2, tmp3};\n"
+    //     "}\n"
+    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
+    //     : "r"(a.x),
+    //     "r"(a.y),
+    //     "r"(a.z),
+    //     "r"(a.w),
+    //     "r"(b.x),
+    //     "r"(b.y),
+    //     "r"(c.x),
+    //     "r"(c.y),
+    //     "r"(c.z),
+    //     "r"(c.w),
+    //     "n"(K / 2));
+    // #endif
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return d;
 }
...