Commit d21ab0f5 authored by fengzch's avatar fengzch
Browse files

fix: use rocm

parent 181f4e43
......@@ -4,5 +4,6 @@ source /usr/local/bin/fastpt -T
export CPLUS_INCLUDE_PATH=/opt/dtk/roctracer/include:$CPLUS_INCLUDE_PATH
export AMDGPU_TARGETS="gfx906;gfx926;gfx928;gfx936"
export FASTPT_USE_ASM=1
CXX=hipcc CC=hipcc python setup.py bdist_wheel
......@@ -308,7 +308,7 @@ def check_hardware_compatibility(quantization_config: dict, device: str | torch.
if sm == "120": # you can only use the fp4 models
if quantization_config["weight"]["dtype"] != "fp4_e2m1_all":
raise ValueError('Please use "fp4" quantization for Blackwell GPUs. ')
elif sm in ["75", "80", "86", "89"]:
elif sm in ["75", "80", "86", "89", "92", "93"]:
if quantization_config["weight"]["dtype"] != "int4":
raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs. ')
else:
......
......@@ -12,6 +12,7 @@ https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutl
#pragma once
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cstdint>
__forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
......@@ -67,12 +68,17 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
// Finally, we construct the output numbers.
// Convert elt_01
// asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
h[0] = __hsub(h[0], __float2half(1024.0f));
// Convert elt_23
// asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
h[1] = __hfma(h[1], __float2half(0.0625f), __float2half(-64.0f));
// Convert elt_45
// asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
h[2] = __hsub(h[2], __float2half(1024.0f));
// Convert elt_67
// asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
h[3] = __hfma(h[3], __float2half(0.0625f), __float2half(-64.0f));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
......@@ -121,11 +127,16 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &so
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(BF16_ONE), "r"(BF16_BIAS));
// Convert elt_23
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(BF16_ONE), "r"(BF16_BIAS));
// Convert elt_45
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
// Convert elt_67
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
// asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(BF16_ONE), "r"(BF16_BIAS));
// // Convert elt_23
// asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(BF16_ONE), "r"(BF16_BIAS));
// // Convert elt_45
// asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
// // Convert elt_67
// asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
h[0] = __hsub(h[0], __float2bfloat16_rn(128.0f));
h[1] = __hsub(h[1], __float2bfloat16_rn(128.0f));
h[2] = __hsub(h[2], __float2bfloat16_rn(128.0f));
h[3] = __hsub(h[3], __float2bfloat16_rn(128.0f));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
......@@ -81,10 +81,10 @@ __device__ void sync_slice(int slice_id) {
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
uint32_t smem_int_ptr;
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
// asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
// : "=r"(smem_int_ptr)
// : "l"(ptr));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return smem_int_ptr;
}
......@@ -92,38 +92,41 @@ template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
// asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
// : "r"(addr));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
// asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
// : "r"(addr));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) {
const int cp_size = 16;
asm volatile("{"
" .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
" @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
"}" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
// asm volatile("{"
// " .reg .pred p;"
// " setp.ne.b32 p, %0, 0;"
// " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
// "}" ::"r"((int)mask),
// "r"(smem_int_ptr),
// "l"(src),
// "n"(cp_size));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
template<typename f16_t>
......@@ -131,39 +134,41 @@ __device__ __inline__ void mma_m16n8k16(float *C_warp, f16_t *A_shared_warp, f16
template<>
__device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
: "r"(((unsigned *)A_shared_warp)[0]),
"r"(((unsigned *)A_shared_warp)[1]),
"r"(((unsigned *)A_shared_warp)[2]),
"r"(((unsigned *)A_shared_warp)[3]),
"r"(((unsigned *)B_shared_warp)[0]),
"r"(((unsigned *)B_shared_warp)[1]),
"f"(((float *)C_warp)[0]),
"f"(((float *)C_warp)[1]),
"f"(((float *)C_warp)[2]),
"f"(((float *)C_warp)[3]));
// asm volatile(
// "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
// "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
// : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
// : "r"(((unsigned *)A_shared_warp)[0]),
// "r"(((unsigned *)A_shared_warp)[1]),
// "r"(((unsigned *)A_shared_warp)[2]),
// "r"(((unsigned *)A_shared_warp)[3]),
// "r"(((unsigned *)B_shared_warp)[0]),
// "r"(((unsigned *)B_shared_warp)[1]),
// "f"(((float *)C_warp)[0]),
// "f"(((float *)C_warp)[1]),
// "f"(((float *)C_warp)[2]),
// "f"(((float *)C_warp)[3]));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
template<>
__device__ __inline__ void
mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
: "r"(((unsigned *)A_shared_warp)[0]),
"r"(((unsigned *)A_shared_warp)[1]),
"r"(((unsigned *)A_shared_warp)[2]),
"r"(((unsigned *)A_shared_warp)[3]),
"r"(((unsigned *)B_shared_warp)[0]),
"r"(((unsigned *)B_shared_warp)[1]),
"f"(((float *)C_warp)[0]),
"f"(((float *)C_warp)[1]),
"f"(((float *)C_warp)[2]),
"f"(((float *)C_warp)[3]));
// asm volatile(
// "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
// "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
// : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
// : "r"(((unsigned *)A_shared_warp)[0]),
// "r"(((unsigned *)A_shared_warp)[1]),
// "r"(((unsigned *)A_shared_warp)[2]),
// "r"(((unsigned *)A_shared_warp)[3]),
// "r"(((unsigned *)B_shared_warp)[0]),
// "r"(((unsigned *)B_shared_warp)[1]),
// "f"(((float *)C_warp)[0]),
// "f"(((float *)C_warp)[1]),
// "f"(((float *)C_warp)[2]),
// "f"(((float *)C_warp)[3]));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
......@@ -944,10 +949,11 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
int M,
int N,
int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
#endif
//#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// trap_unsupported_arch();
// return;
//#endif
// printf("LOG(INFO) %s: %d %s\n", __FILE__, __LINE__, __func__);
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
......
......@@ -171,7 +171,7 @@ inline __device__ T ldg(const T *val) {
#define float22bf162 __float22bfloat162_rn
#define bf162bf162 __bfloat162bfloat162
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
......@@ -203,7 +203,7 @@ inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
return val[0];
#else
return __ldg(val);
......@@ -212,7 +212,7 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val) {
template<>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16 *val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
return val[0];
#else
return __ldg(val);
......
......@@ -191,14 +191,26 @@ public:
// set nan values to -inf
__device__ __forceinline__ static half2_t fix_nan(half2_t input) {
static constexpr float neginf = -std::numeric_limits<float>::infinity();
// static constexpr float neginf = -std::numeric_limits<float>::infinity();
/**
* In accordance to the IEEE-754R standard,
* if one of the input parameters to fminf(), fmin(), fmaxf(), or fmax() is NaN,
* but not the other,
* the result is the non-NaN parameter.
*/
return __hmax2(input, half2_t(neginf, neginf));
// return __hmax2(input, half2_t(neginf, neginf));
half_t lo = __low2half(input);
half_t hi = __high2half(input);
// Step 2: Convert to float to use isnan (HIP supports __hisnan)
// Option A: Use __hisnan if available (preferred)
half_t neg_inf = __float2half(-std::numeric_limits<float>::infinity());
half_t out_lo = __hisnan(lo) ? neg_inf : lo;
half_t out_hi = __hisnan(hi) ? neg_inf : hi;
// Step 3: Pack back into half2_t
return __halves2half2(out_lo, out_hi);
}
__device__ __forceinline__ static float fix_nan(float input) {
......@@ -511,7 +523,8 @@ public:
if (alwaysfalse) {
dummy = clock();
}
// asm volatile ("membar.cta;");
asm volatile ("membar.cta;");
}
}
......
......@@ -50,29 +50,33 @@ template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data;
asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
"@loadpred ld.global.nc.b32 %0, [%1];"
"}"
: "=r"(data)
: "l"(addr), "r"((int)pred));
// asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
// "@loadpred ld.global.nc.b32 %0, [%1];"
// "}"
// : "=r"(data)
// : "l"(addr), "r"((int)pred));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 8) {
uint2 data;
asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
"@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
"}"
: "=r"(data.x), "=r"(data.y)
: "l"(addr), "r"((int)pred));
// asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
// "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
// "}"
// : "=r"(data.x), "=r"(data.y)
// : "l"(addr), "r"((int)pred));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 16) {
uint4 data;
asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
"@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
"}"
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
: "l"(addr), "r"((int)pred));
// asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
// "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
// "}"
// : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
// : "l"(addr), "r"((int)pred));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return *reinterpret_cast<T *>(&data);
}
......@@ -88,17 +92,21 @@ __device__ __forceinline__ static void store(T *addr, T val) {
if constexpr (shmem) {
if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val);
asm volatile(
"st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
// asm volatile(
// "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return;
}
if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val);
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
"r"(data.x),
"r"(data.y),
"r"(data.z),
"r"(data.w));
// asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
// "r"(data.x),
// "r"(data.y),
// "r"(data.z),
// "r"(data.w));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return;
}
*addr = val;
......@@ -127,33 +135,39 @@ template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data = *reinterpret_cast<uint32_t *>(&val);
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.b32 [%1], %2;"
"}" ::"r"((int)pred),
"l"(addr),
"r"(data));
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred st.global.cg.b32 [%1], %2;"
// "}" ::"r"((int)pred),
// "l"(addr),
// "r"(data));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return;
}
if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val);
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
"}" ::"r"((int)pred),
"l"(addr),
"r"(data.x),
"r"(data.y));
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
// "}" ::"r"((int)pred),
// "l"(addr),
// "r"(data.x),
// "r"(data.y));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return;
}
if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val);
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
"}" ::"r"((int)pred),
"l"(addr),
"r"(data.x),
"r"(data.y),
"r"(data.z),
"r"(data.w));
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
// "}" ::"r"((int)pred),
// "l"(addr),
// "r"(data.x),
// "r"(data.y),
// "r"(data.z),
// "r"(data.w));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return;
}
......@@ -194,14 +208,16 @@ __device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
: "l"((ptr))); // limengmeng
: "l"((ptr)));
}
template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "=r"(*reinterpret_cast<uint32_t *>(&x))
: "r"(*reinterpret_cast<uint32_t *>(&x)));
// asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
// : "=r"(*reinterpret_cast<uint32_t *>(&x))
// : "r"(*reinterpret_cast<uint32_t *>(&x)));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return x;
}
......@@ -215,7 +231,9 @@ __device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
uint32_t result;
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
// asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return result;
}
......@@ -225,7 +243,9 @@ __device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
uint32_t result;
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
// asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return result;
}
......@@ -235,22 +255,27 @@ __device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
uint32_t result;
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
// asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return result;
}
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
uint32_t result;
asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
: "=r"(result)
: "f"(value.y), "f"(value.x));
// asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
// : "=r"(result)
// : "f"(value.y), "f"(value.x));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return result;
}
__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
uint16_t lo, hi;
asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
// asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
// asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return uint32_t(lo) | (uint32_t(hi) << 16);
}
......@@ -347,15 +372,17 @@ __device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bf
};
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
// asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
"}" ::"r"((int)pred),
"l"(addr),
"f"(val));
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
// "}" ::"r"((int)pred),
// "l"(addr),
// "f"(val));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
template<int cnt, typename F>
......@@ -369,9 +396,10 @@ __device__ __forceinline__ static void unrolled_loop(F &&lambda) {
__device__ __forceinline__ static float int2float_fast(int val) {
float fval;
// fval = (val & 0x7FFFFF) ^ 0x4B400000
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=f"(fval)
: "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
// : "=f"(fval)
// : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return fval - 12582912.0f;
}
......@@ -388,12 +416,13 @@ __device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
uint32_t ival;
uint32_t hval;
// ival.lo = x.lo; ival.hi = y.lo;
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
// asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
ival = ival >> 4;
// (val & 0x03FF03FF) ^ 0x76007600
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=r"(hval)
: "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
// : "=r"(hval)
// : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}
// val in [-4096, 4095], steps of 8, round to nearest
......@@ -407,11 +436,12 @@ __device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
uint32_t hval;
// ival.lo = x.hi; ival.hi = y.hi;
// <=> divide x and y by 65536 and pack them
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
// asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
// (val & 0x03FF03FF) ^ 0x72007200
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=r"(hval)
: "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
// : "=r"(hval)
// : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}
// val in [-512, 511]
......@@ -420,11 +450,12 @@ __device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
uint32_t hval;
// ival.lo = x.lo; ival.hi = y.lo;
// <=> divide x and y by 65536 and pack them
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
//asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
// (val & 0x03FF03FF) ^ 0x66006600
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=r"(hval)
: "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
// : "=r"(hval)
// : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}
......
......@@ -247,7 +247,7 @@ public:
// "r"(wmscale),
// "n"(0),
// "h"((short)(idb * 2 + 1)));
std::cout << __func__ << "mma_fp4 is not implemented for HIP yet[asm error!!!]" << std::endl;
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return out;
}
......@@ -334,7 +334,8 @@ public:
dummy = clock();
}
// asm volatile ("membar.cta;");
asm volatile ("membar.cta;");
}
}
......@@ -916,7 +917,9 @@ public:
}
// #endif
// asm volatile ("membar.cta;");
asm volatile ("membar.cta;");
}
}
......
......@@ -10,40 +10,42 @@ public:
__device__ __forceinline__ static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
// packed_psum_t psum;
asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
: "r"(act.x),
"r"(act.y),
"r"(act.z),
"r"(act.w),
"r"(wgt.x),
"r"(wgt.y),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[0]),
"r"(psum.data[1]),
"r"(psum.data[2]),
"r"(psum.data[3]));
asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
: "r"(act.x),
"r"(act.y),
"r"(act.z),
"r"(act.w),
"r"(wgt.z),
"r"(wgt.w),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[4]),
"r"(psum.data[5]),
"r"(psum.data[6]),
"r"(psum.data[7]));
// asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
// "{%0, %1, %2, %3},"
// "{%4, %5, %6, %7},"
// "{%8, %9},"
// "{%10, %11, %12, %13};\n"
// : "=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
// : "r"(act.x),
// "r"(act.y),
// "r"(act.z),
// "r"(act.w),
// "r"(wgt.x),
// "r"(wgt.y),
// // "r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[0]),
// "r"(psum.data[1]),
// "r"(psum.data[2]),
// "r"(psum.data[3]));
// asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
// "{%0, %1, %2, %3},"
// "{%4, %5, %6, %7},"
// "{%8, %9},"
// "{%10, %11, %12, %13};\n"
// : "=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
// : "r"(act.x),
// "r"(act.y),
// "r"(act.z),
// "r"(act.w),
// "r"(wgt.z),
// "r"(wgt.w),
// // "r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[4]),
// "r"(psum.data[5]),
// "r"(psum.data[6]),
// "r"(psum.data[7]));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return psum;
}
......@@ -418,7 +420,8 @@ public:
// dummy = clock();
// }
// asm volatile ("membar.cta;");
asm volatile ("membar.cta;");
}
}
......
......@@ -110,65 +110,65 @@ __device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b,
uint4 d;
static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x),
"r"(a.y),
"r"(a.z),
"r"(a.w),
"r"(b.x),
"r"(b.y),
"r"(c.x),
"r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K),
"C"(AType::value),
"C"(BType::value));
#else
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x),
"r"(a.y),
"r"(a.z),
"r"(a.w),
"r"(b.x),
"r"(b.y),
"r"(c.x),
"r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K / 2),
"C"(AType::value),
"C"(BType::value));
#endif
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
// "{%0, %1, %2, %3},"
// "{%4, %5, %6, %7},"
// "{%8, %9},"
// "{%10, %11, %12, %13};\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x),
// "r"(a.y),
// "r"(a.z),
// "r"(a.w),
// "r"(b.x),
// "r"(b.y),
// "r"(c.x),
// "r"(c.y),
// "r"(c.z),
// "r"(c.w),
// "n"(K),
// "C"(AType::value),
// "C"(BType::value));
// #else
// asm volatile("{"
// ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
// "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
// "{tmp0, tmp1},"
// "{%4},"
// "{%8},"
// "{%10, %11};\n"
// "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
// "{tmp2, tmp3},"
// "{%5},"
// "{%8},"
// "{%12, %13};\n"
// "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
// "{%0, %1},"
// "{%6},"
// "{%9},"
// "{tmp0, tmp1};\n"
// "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
// "{%2, %3},"
// "{%7},"
// "{%9},"
// "{tmp2, tmp3};\n"
// "}\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x),
// "r"(a.y),
// "r"(a.z),
// "r"(a.w),
// "r"(b.x),
// "r"(b.y),
// "r"(c.x),
// "r"(c.y),
// "r"(c.z),
// "r"(c.w),
// "n"(K / 2),
// "C"(AType::value),
// "C"(BType::value));
// #endif
return d;
}
......
......@@ -36,31 +36,32 @@ using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
__device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
: "=r"(d.x), "=r"(d.y)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#else
asm volatile("{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
: "=r"(d.x), "=r"(d.y)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#endif
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
// "{%0, %1},"
// "{%2, %3, %4, %5},"
// "{%6, %7},"
// "{%8, %9};\n"
// : "=r"(d.x), "=r"(d.y)
// : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
// #else
// asm volatile("{"
// ".reg .b32 tmp0, tmp1;"
// "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
// "{tmp0, tmp1},"
// "{%2, %3},"
// "{%6},"
// "{%8, %9};\n"
// "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
// "{%0, %1},"
// "{%4, %5},"
// "{%7},"
// "{tmp0, tmp1};"
// "}\n"
// : "=r"(d.x), "=r"(d.y)
// : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
// #endif
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return d;
}
......@@ -71,13 +72,14 @@ __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2
template<>
__device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
uint4 d;
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
// asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
// "{%0, %1, %2, %3},"
// "{%4, %5, %6, %7},"
// "{%8, %9},"
// "{%10, %11, %12, %13};\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return d;
}
#endif
......@@ -85,31 +87,32 @@ __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2
template<>
__device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#else
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#endif
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
// "{%0, %1, %2, %3},"
// "{%4, %5, %6, %7},"
// "{%8, %9},"
// "{%10, %11, %12, %13};\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
// #else
// asm volatile("{"
// ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
// "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
// "{tmp0, tmp1, tmp2, tmp3},"
// "{%4, %5},"
// "{%8},"
// "{%10, %11, %12, %13};\n"
// "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
// "{%0, %1, %2, %3},"
// "{%6, %7},"
// "{%9},"
// "{tmp0, tmp1, tmp2, tmp3};"
// "}\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
// #endif
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return d;
}
......@@ -121,7 +124,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
uint4 d;
static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// asm volatile(
// "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
// "{%0, %1, %2, %3},"
......@@ -130,43 +133,44 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
// "{%10, %11, %12, %13};\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x),
"r"(a.y),
"r"(a.z),
"r"(a.w),
"r"(b.x),
"r"(b.y),
"r"(c.x),
"r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K / 2));
#endif
// #else
// asm volatile("{"
// ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
// "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
// "{tmp0, tmp1},"
// "{%4},"
// "{%8},"
// "{%10, %11};\n"
// "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
// "{tmp2, tmp3},"
// "{%5},"
// "{%8},"
// "{%12, %13};\n"
// "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
// "{%0, %1},"
// "{%6},"
// "{%9},"
// "{tmp0, tmp1};\n"
// "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
// "{%2, %3},"
// "{%7},"
// "{%9},"
// "{tmp2, tmp3};\n"
// "}\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x),
// "r"(a.y),
// "r"(a.z),
// "r"(a.w),
// "r"(b.x),
// "r"(b.y),
// "r"(c.x),
// "r"(c.y),
// "r"(c.z),
// "r"(c.w),
// "n"(K / 2));
// #endif
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return d;
}
......@@ -175,7 +179,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
uint4 d;
static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// asm volatile(
// "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
// "{%0, %1, %2, %3},"
......@@ -184,43 +188,44 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
// "{%10, %11, %12, %13};\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x),
"r"(a.y),
"r"(a.z),
"r"(a.w),
"r"(b.x),
"r"(b.y),
"r"(c.x),
"r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K / 2));
#endif
// #else
// asm volatile("{"
// ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
// "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
// "{tmp0, tmp1},"
// "{%4},"
// "{%8},"
// "{%10, %11};\n"
// "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
// "{tmp2, tmp3},"
// "{%5},"
// "{%8},"
// "{%12, %13};\n"
// "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
// "{%0, %1},"
// "{%6},"
// "{%9},"
// "{tmp0, tmp1};\n"
// "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
// "{%2, %3},"
// "{%7},"
// "{%9},"
// "{tmp2, tmp3};\n"
// "}\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x),
// "r"(a.y),
// "r"(a.z),
// "r"(a.w),
// "r"(b.x),
// "r"(b.y),
// "r"(c.x),
// "r"(c.y),
// "r"(c.z),
// "r"(c.w),
// "n"(K / 2));
// #endif
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
return d;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment