Commit 8be63f64 authored by fengzch

PTX instruction replacement

parent d21ab0f5
@@ -81,6 +81,53 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
     // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
+// Device-side bfloat16-to-float conversion
+__device__ float bf16_to_float_device(uint16_t bf16) {
+    // bfloat16 to float: the bf16 bits become the high 16 bits of the float
+    uint32_t val = (uint32_t)bf16 << 16;
+    return __uint_as_float(val);
+}
+// Device-side float-to-bfloat16 conversion
+__device__ uint16_t float_to_bf16_device(float f) {
+    // float to bfloat16: keep the high 16 bits of the float
+    uint32_t float_bits = __float_as_uint(f);
+    // round to nearest even
+    uint32_t rounding_bias = ((float_bits >> 16) & 1) + 0x7FFF;
+    return (uint16_t)((float_bits + rounding_bias) >> 16);
+}
+// C++ implementation of a packed bfloat16x2 FMA
+__device__ uint32_t fma_bf16x2_cpp(uint32_t a, uint32_t b, uint32_t c) {
+    // Unpack the high and low halves of a, b, and c
+    uint16_t a_high = (uint16_t)(a >> 16);
+    uint16_t a_low = (uint16_t)(a & 0xFFFF);
+    uint16_t b_high = (uint16_t)(b >> 16);
+    uint16_t b_low = (uint16_t)(b & 0xFFFF);
+    uint16_t c_high = (uint16_t)(c >> 16);
+    uint16_t c_low = (uint16_t)(c & 0xFFFF);
+    // Convert the bfloat16 halves to float for the arithmetic
+    // High half: (a_high * b_high) + c_high
+    float a_high_f = bf16_to_float_device(a_high);
+    float b_high_f = bf16_to_float_device(b_high);
+    float c_high_f = bf16_to_float_device(c_high);
+    float result_high_f = a_high_f * b_high_f + c_high_f;
+    uint16_t result_high = float_to_bf16_device(result_high_f);
+    // Low half: (a_low * b_low) + c_low
+    float a_low_f = bf16_to_float_device(a_low);
+    float b_low_f = bf16_to_float_device(b_low);
+    float c_low_f = bf16_to_float_device(c_low);
+    float result_low_f = a_low_f * b_low_f + c_low_f;
+    uint16_t result_low = float_to_bf16_device(result_low_f);
+    // Repack the result
+    return ((uint32_t)result_high << 16) | result_low;
+}
 __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
     // dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
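
Note on the helpers added above: float_to_bf16_device implements round-to-nearest-even via the bias ((bits >> 16) & 1) + 0x7FFF. A minimal host-side sketch of the same bit trick, useful for checking the rounding behavior off-device (the function names here are illustrative, not part of the commit):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static float bf16_to_float(uint16_t bf16) {
        uint32_t val = (uint32_t)bf16 << 16;  // bf16 bits become the float's high half
        float f;
        memcpy(&f, &val, sizeof(f));
        return f;
    }

    static uint16_t float_to_bf16(float f) {
        uint32_t bits;
        memcpy(&bits, &f, sizeof(bits));
        // The bias depends on the lowest kept bit, giving round-to-nearest-even.
        uint32_t rounding_bias = ((bits >> 16) & 1) + 0x7FFF;
        return (uint16_t)((bits + rounding_bias) >> 16);
    }

    int main() {
        // 1 + 2^-8 lies exactly between two bf16 values; ties go to the even mantissa.
        printf("%f\n", bf16_to_float(float_to_bf16(1.0f + 1.0f / 256.0f))); // 1.000000
        printf("%f\n", bf16_to_float(float_to_bf16(3.14159f)));            // 3.140625
        return 0;
    }
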
@@ -103,22 +150,27 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &so
     static constexpr uint32_t MASK = 0x000f000f;
     static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
-    // Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
-                 : "=r"(h[0])
-                 : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-    // Extract elt_23 ((i4s >> 4) & 0x000f000f) | 0x43004300
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
-                 : "=r"(h[1])
-                 : "r"(i4s >> 4), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-    // Extract elt_45 ((i4s >> 8) & 0x000f000f) | 0x43004300
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
-                 : "=r"(h[2])
-                 : "r"(i4s >> 8), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-    // Extract elt_67 ((i4s >> 12) & 0x000f000f) | 0x43004300
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
-                 : "=r"(h[3])
-                 : "r"(i4s >> 12), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+    // // Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
+    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+    //              : "=r"(h[0])
+    //              : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+    // // Extract elt_23 ((i4s >> 4) & 0x000f000f) | 0x43004300
+    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+    //              : "=r"(h[1])
+    //              : "r"(i4s >> 4), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+    // // Extract elt_45 ((i4s >> 8) & 0x000f000f) | 0x43004300
+    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+    //              : "=r"(h[2])
+    //              : "r"(i4s >> 8), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+    // // Extract elt_67 ((i4s >> 12) & 0x000f000f) | 0x43004300
+    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+    //              : "=r"(h[3])
+    //              : "r"(i4s >> 12), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+    h[0] = ((i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM);
+    h[1] = (((i4s >> 4) & MASK) | I4s_TO_BF16s_MAGIC_NUM);
+    h[2] = (((i4s >> 8) & MASK) | I4s_TO_BF16s_MAGIC_NUM);
+    h[3] = (((i4s >> 12) & MASK) | I4s_TO_BF16s_MAGIC_NUM);
     // static constexpr uint32_t BF16_BIAS = 0xC308C308;
     // This is the BF16 {-128, -128} represented as an integer, we do not need to map to [-8, 7]
@@ -134,9 +186,8 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &so
     // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
     // // Convert elt_67
     // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
-    h[0] = __hsub(h[0], __float2bfloat16_rn(128.0f));
-    h[1] = __hsub(h[1], __float2bfloat16_rn(128.0f));
-    h[2] = __hsub(h[2], __float2bfloat16_rn(128.0f));
-    h[3] = __hsub(h[3], __float2bfloat16_rn(128.0f));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    h[0] = fma_bf16x2_cpp(h[0], BF16_ONE, BF16_BIAS);
+    h[1] = fma_bf16x2_cpp(h[1], BF16_ONE, BF16_BIAS);
+    h[2] = fma_bf16x2_cpp(h[2], BF16_ONE, BF16_BIAS);
+    h[3] = fma_bf16x2_cpp(h[3], BF16_ONE, BF16_BIAS);
 }
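
The replacement above works because 0x4300 is the bf16 encoding of 128.0 with an empty mantissa: OR'ing a 4-bit nibble into the low mantissa bits yields exactly 128 + nibble, so the final packed FMA with BF16_ONE and a negative bias recovers the integer without any int-to-float conversion. A host-side sketch of that identity (illustrative only; the helper name is not from the commit):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static float bf16_bits_to_float(uint16_t b) {
        uint32_t v = (uint32_t)b << 16;
        float f;
        memcpy(&f, &v, sizeof(f));
        return f;
    }

    int main() {
        for (uint16_t nibble = 0; nibble < 16; ++nibble) {
            // (nibble | 0x4300) read as bf16 is 128 * (1 + nibble/128) = 128 + nibble,
            float biased = bf16_bits_to_float(nibble | 0x4300);
            // so subtracting 128 recovers the unsigned nibble exactly.
            printf("0x%X -> %5.1f -> %4.1f\n", nibble, biased, biased - 128.0f);
        }
        return 0;
    }
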
@@ -92,28 +92,26 @@ template<typename f16_t>
 __inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
     static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                   "ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
-    // asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
-    //              "{%0, %1, %2, %3}, [%4];"
-    //              : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
-    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
-    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
-    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
-    //              : "r"(addr));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+                 "{%0, %1, %2, %3}, [%4];"
+                 : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
+                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
+                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
+                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
+                 : "r"(addr));
 }
 template<typename f16_t>
 __inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
     static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                   "ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
-    // asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
-    //              "{%0, %1, %2, %3}, [%4];"
-    //              : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
-    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
-    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
-    //                "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
-    //              : "r"(addr));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+                 "{%0, %1, %2, %3}, [%4];"
+                 : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
+                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
+                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
+                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
+                 : "r"(addr));
 }
 __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) {
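
The re-enabled ldmatrix wrappers take addr as a 32-bit address inside the shared-memory window, not a generic pointer. A hedged usage sketch of how such an address is usually produced with __cvta_generic_to_shared (the kernel, tile shape, and per-lane indexing here are illustrative, not from this repository):

    __global__ void ldmatrix_demo(half *out) {
        __shared__ half tile[16 * 16];                 // illustrative tile
        tile[threadIdx.x] = __float2half((float)threadIdx.x);
        __syncthreads();
        half frag[8];                                  // .x4 fills four 32-bit registers per lane
        // Each lane supplies a row address; the %16 indexing is a toy layout.
        uint32_t addr = static_cast<uint32_t>(
            __cvta_generic_to_shared(&tile[(threadIdx.x % 16) * 8]));
        ldmatrix_m8n8_x4_b16(frag, /*ax0_0=*/0, addr);
        out[threadIdx.x] = frag[0];
    }
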
@@ -383,10 +381,10 @@ __global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
                               int M,
                               int N,
                               int K) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-    trap_unsupported_arch();
-    return;
-#endif
+    // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    //     trap_unsupported_arch();
+    //     return;
+    // #endif
     using f162_t = typename packed_as<f16_t, 2>::type;
     constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
...
@@ -46,37 +46,73 @@ __device__ __forceinline__ static T load(const T *addr) {
     return *addr;
 }
+// template<typename T>
+// __device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
+//     if constexpr (sizeof(T) == 4) {
+//         uint32_t data;
+//         // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
+//         //              "@loadpred ld.global.nc.b32 %0, [%1];"
+//         //              "}"
+//         //              : "=r"(data)
+//         //              : "l"(addr), "r"((int)pred));
+//         return *reinterpret_cast<T *>(&data);
+//     }
+//     if constexpr (sizeof(T) == 8) {
+//         uint2 data;
+//         // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
+//         //              "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
+//         //              "}"
+//         //              : "=r"(data.x), "=r"(data.y)
+//         //              : "l"(addr), "r"((int)pred));
+//         return *reinterpret_cast<T *>(&data);
+//     }
+//     if constexpr (sizeof(T) == 16) {
+//         uint4 data;
+//         // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
+//         //              "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
+//         //              "}"
+//         //              : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
+//         //              : "l"(addr), "r"((int)pred));
+//         return *reinterpret_cast<T *>(&data);
+//     }
+//     T result;
+//     if (pred) {
+//         result = *addr;
+//     }
+//     return result;
+// }
 template<typename T>
 __device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
     if constexpr (sizeof(T) == 4) {
         uint32_t data;
-        // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
-        //              "@loadpred ld.global.nc.b32 %0, [%1];"
-        //              "}"
-        //              : "=r"(data)
-        //              : "l"(addr), "r"((int)pred));
-        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+        if (pred) {
+            const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
+            unsigned char *dst = reinterpret_cast<unsigned char *>(&data);
+            #pragma unroll
+            for (int i = 0; i < 4; ++i) dst[i] = src[i];
+        }
         return *reinterpret_cast<T *>(&data);
     }
     if constexpr (sizeof(T) == 8) {
         uint2 data;
-        // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
-        //              "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
-        //              "}"
-        //              : "=r"(data.x), "=r"(data.y)
-        //              : "l"(addr), "r"((int)pred));
-        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+        if (pred) {
+            const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
+            unsigned char *dst = reinterpret_cast<unsigned char *>(&data);
+            #pragma unroll
+            for (int i = 0; i < 8; ++i) dst[i] = src[i];
+        }
         return *reinterpret_cast<T *>(&data);
     }
     if constexpr (sizeof(T) == 16) {
         uint4 data;
-        // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
-        //              "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
-        //              "}"
-        //              : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
-        //              : "l"(addr), "r"((int)pred));
-        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+        if (pred) {
+            const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
+            unsigned char *dst = reinterpret_cast<unsigned char *>(&data);
+            #pragma unroll
+            for (int i = 0; i < 16; ++i) dst[i] = src[i];
+        }
         return *reinterpret_cast<T *>(&data);
     }
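
The byte loop above relies on the compiler coalescing a fixed-count copy into one wide load. An equivalent, shorter sketch using memcpy, which optimizing CUDA/HIP compilers also collapse for these fixed sizes (the helper name is illustrative); as in the original, the return value is unspecified when pred is false, matching the predicated PTX load it replaces:

    template<typename T>
    __device__ __forceinline__ static T load_pred_memcpy(const T *addr, bool pred) {
        T data;
        if (pred) {
            memcpy(&data, addr, sizeof(T));  // fixed size: compiles to a single vector load
        }
        return data;
    }
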
@@ -92,21 +128,17 @@ __device__ __forceinline__ static void store(T *addr, T val) {
     if constexpr (shmem) {
         if constexpr (sizeof(T) == 8) {
             uint2 data = *reinterpret_cast<uint2 *>(&val);
-            // asm volatile(
-            //     "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
-            // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+            asm volatile(
+                "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
             return;
         }
         if constexpr (sizeof(T) == 16) {
             uint4 data = *reinterpret_cast<uint4 *>(&val);
-            // asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
-            //              "r"(data.x),
-            //              "r"(data.y),
-            //              "r"(data.z),
-            //              "r"(data.w));
-            // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+            asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
+                         "r"(data.x),
+                         "r"(data.y),
+                         "r"(data.z),
+                         "r"(data.w));
             return;
         }
         *addr = val;
@@ -115,17 +147,17 @@ __device__ __forceinline__ static void store(T *addr, T val) {
     if constexpr (sizeof(T) == 4) {
         // __stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
         *reinterpret_cast<unsigned int *>(addr) = *reinterpret_cast<unsigned int *>(&val);
         return;
     }
     if constexpr (sizeof(T) == 8) {
         // __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
         *reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
         return;
     }
     if constexpr (sizeof(T) == 16) {
         // __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
         *reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
         return;
     }
     *addr = val;
@@ -135,39 +167,33 @@ template<typename T>
 __device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
     if constexpr (sizeof(T) == 4) {
         uint32_t data = *reinterpret_cast<uint32_t *>(&val);
-        // asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
-        //              "@storepred st.global.cg.b32 [%1], %2;"
-        //              "}" ::"r"((int)pred),
-        //              "l"(addr),
-        //              "r"(data));
-        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                     "@storepred st.global.cg.b32 [%1], %2;"
+                     "}" ::"r"((int)pred),
+                     "l"(addr),
+                     "r"(data));
         return;
     }
     if constexpr (sizeof(T) == 8) {
         uint2 data = *reinterpret_cast<uint2 *>(&val);
-        // asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
-        //              "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
-        //              "}" ::"r"((int)pred),
-        //              "l"(addr),
-        //              "r"(data.x),
-        //              "r"(data.y));
-        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
+                     "}" ::"r"((int)pred),
+                     "l"(addr),
+                     "r"(data.x),
+                     "r"(data.y));
        return;
     }
     if constexpr (sizeof(T) == 16) {
         uint4 data = *reinterpret_cast<uint4 *>(&val);
-        // asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
-        //              "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
-        //              "}" ::"r"((int)pred),
-        //              "l"(addr),
-        //              "r"(data.x),
-        //              "r"(data.y),
-        //              "r"(data.z),
-        //              "r"(data.w));
-        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
+                     "}" ::"r"((int)pred),
+                     "l"(addr),
+                     "r"(data.x),
+                     "r"(data.y),
+                     "r"(data.z),
+                     "r"(data.w));
         return;
     }
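
For targets where the predicated st.global.cg PTX is unavailable, a guarded __stcg (the cache-global store intrinsic already referenced in the commented store() paths above) gives the same behavior. A minimal sketch under that assumption (the fallback name is illustrative):

    template<typename T>
    __device__ __forceinline__ static void store_pred_fallback(T *addr, T val, bool pred) {
        if (!pred) return;  // predicate off: no memory traffic
        if constexpr (sizeof(T) == 16) {
            __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
        } else if constexpr (sizeof(T) == 8) {
            __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
        } else if constexpr (sizeof(T) == 4) {
            __stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
        } else {
            *addr = val;
        }
    }
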
@@ -229,11 +255,17 @@ template<>
 __device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
     int v1, v2;
     uint32_t result;
-    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
-    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
+    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
+    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
     // asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    v1 = __float2int_rn(value.x);
+    v2 = __float2int_rn(value.y);
+    int s1 = max(-8, min(7, v1));
+    int s2 = max(-8, min(7, v2));
+    unsigned int u1 = s1 & 0xF;
+    unsigned int u2 = s2 & 0xF;
+    result = (u2 << 4) | u1;
     return result;
 }
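
A worked trace of this path, assuming value = {3.7f, -9.2f}:

    // v1 = __float2int_rn( 3.7f) =  4        v2 = __float2int_rn(-9.2f) = -9
    // s1 = clamp( 4, -8, 7)      =  4        s2 = clamp(-9, -8, 7)      = -8
    // u1 = 4 & 0xF               = 0x4       u2 = -8 & 0xF              = 0x8
    // result = (0x8 << 4) | 0x4  = 0x84      (two's-complement nibbles, low lane first)
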
@@ -241,11 +273,15 @@ template<>
 __device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
     int v1, v2;
     uint32_t result;
-    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
-    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
+    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
+    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
     // asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    v1 = __float2int_rn(value.x);
+    v2 = __float2int_rn(value.y);
+    unsigned int u1 = static_cast<unsigned int>(max(0, min(15, v1)));
+    unsigned int u2 = static_cast<unsigned int>(max(0, min(15, v2)));
+    result = (u2 << 4) | u1;
     return result;
 }
@@ -253,21 +289,29 @@ template<>
 __device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
     int v1, v2;
     uint32_t result;
-    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
-    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
+    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
+    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
     // asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    v1 = __float2int_rn(value.x); // round to nearest integer (rn = round-to-nearest-even)
+    v2 = __float2int_rn(value.y);
+    // Step 2: saturate to the signed 8-bit range [-128, 127]
+    int s1 = max(-128, min(127, v1));
+    int s2 = max(-128, min(127, v2));
+    // Step 3: convert the signed values to unsigned bit patterns;
+    // masking keeps the 8-bit two's-complement representation
+    unsigned int u1 = s1 & 0xFF; // keep only the low 8 bits
+    unsigned int u2 = s2 & 0xFF;
+    result = (u2 << 8) | u1;
     return result;
 }
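
The same mask trick traced for the 8-bit path, assuming value = {-1.5f, 300.0f}:

    // v1 = __float2int_rn(-1.5f)  =  -2      (ties round to the even integer)
    // v2 = __float2int_rn(300.0f) = 300
    // s1 = clamp( -2, -128, 127)  =  -2      s2 = clamp(300, -128, 127) = 127
    // u1 = -2 & 0xFF              = 0xFE     u2 = 127 & 0xFF            = 0x7F
    // result = (0x7F << 8) | 0xFE = 0x7FFE
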
 __device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
     uint32_t result;
-    // asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
-    //              : "=r"(result)
-    //              : "f"(value.y), "f"(value.x));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
+                 : "=r"(result)
+                 : "f"(value.y), "f"(value.x));
     return result;
 }
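
cvt.rn.satfinite.e2m1x2.f32 packs two floats into two FP4 (e2m1) nibbles: 1 sign, 2 exponent, and 1 mantissa bit, so each nibble can only encode ±{0, 0.5, 1, 1.5, 2, 3, 4, 6}. A hedged host-side reference sketch (nearest-match search; the hardware's round-to-nearest-even tie-breaking is not reproduced exactly by this simple argmin):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // The eight magnitudes representable in e2m1 (FP4), nibble codes 0..7.
    static const float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

    static uint32_t quantize_fp4_ref(float v) {
        uint32_t sign = v < 0.0f ? 0x8u : 0x0u;
        float mag = fabsf(v);
        if (mag > 6.0f) mag = 6.0f;  // satfinite clamps to the largest finite value
        uint32_t best = 0;
        for (uint32_t i = 1; i < 8; ++i)
            if (fabsf(kE2M1[i] - mag) < fabsf(kE2M1[best] - mag)) best = i;
        return sign | best;
    }

    int main() {
        // Pack value = {x = 1.1f, y = -2.2f} the way the PTX above does,
        // assuming the y lane lands in the high nibble (operand order %1 = y, %2 = x).
        uint32_t lo = quantize_fp4_ref(1.1f);   //  1.1 ->  1.0 -> code 0x2
        uint32_t hi = quantize_fp4_ref(-2.2f);  // -2.2 -> -2.0 -> sign|code = 0xC
        printf("0x%02X\n", (hi << 4) | lo);     // 0xC2
        return 0;
    }
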
@@ -372,17 +416,15 @@ __device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bf
 };
 __device__ __forceinline__ static void reduce_add(float *addr, float val) {
-    // asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
 }
 __device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
-    // asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
-    //              "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
-    //              "}" ::"r"((int)pred),
-    //              "l"(addr),
-    //              "f"(val));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
+                 "}" ::"r"((int)pred),
+                 "l"(addr),
+                 "f"(val));
 }
 template<int cnt, typename F>
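
The reduce_add / reduce_add_pred pair above uses red.relaxed.gpu.global.add.f32, a fire-and-forget reduction that returns no value. Where that PTX is unavailable, atomicAdd has the same memory effect (it additionally returns the old value, unused here). A portable fallback sketch (the fallback names are illustrative):

    __device__ __forceinline__ static void reduce_add_fallback(float *addr, float val) {
        atomicAdd(addr, val);
    }

    __device__ __forceinline__ static void reduce_add_pred_fallback(float *addr, float val, bool pred) {
        if (pred) {
            atomicAdd(addr, val);
        }
    }
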
@@ -394,13 +436,15 @@ __device__ __forceinline__ static void unrolled_loop(F &&lambda) {
 // int2float is slow on sm_80 and before
 // val in [-4194304, 4194303]
 __device__ __forceinline__ static float int2float_fast(int val) {
-    float fval;
-    // fval = (val & 0x7FFFFF) ^ 0x4B400000
+    // float fval;
+    // // fval = (val & 0x7FFFFF) ^ 0x4B400000
     // asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
     //              : "=f"(fval)
     //              : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
-    return fval - 12582912.0f;
+    unsigned int temp = (val & 0x7FFFFF) ^ 0x4B400000;
+    float result;
+    memcpy(&result, &temp, sizeof(float));
+    return result - 12582912.0f;
 }
 template<typename To, typename From>
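
Why int2float_fast works: 0x4B400000 is the float 12582912.0f = 1.5 * 2^23, whose 23-bit mantissa can hold an integer directly. Masking val to 23 bits and XOR'ing flips mantissa bit 22, which for val in [-2^22, 2^22) amounts to storing val + 2^22 in the mantissa; the resulting float is exactly 12582912.0f + val, so one subtraction recovers val without an int-to-float convert. A host-side check of that identity (illustrative only):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static float int2float_fast_ref(int val) {
        unsigned int temp = (val & 0x7FFFFF) ^ 0x4B400000;
        float result;
        memcpy(&result, &temp, sizeof(float));
        return result - 12582912.0f;  // 12582912.0f == 0x4B400000 == 1.5 * 2^23
    }

    int main() {
        int tests[] = {-4194304, -1, 0, 1, 12345, 4194303};
        for (int v : tests) {
            printf("%d -> %.1f\n", v, int2float_fast_ref(v));  // prints v unchanged
        }
        return 0;
    }
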
@@ -416,13 +460,12 @@ __device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
     uint32_t ival;
     uint32_t hval;
     // ival.lo = x.lo; ival.hi = y.lo;
-    // asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
+    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
     ival = ival >> 4;
     // (val & 0x03FF03FF) ^ 0x76007600
-    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
-    //              : "=r"(hval)
-    //              : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
+                 : "=r"(hval)
+                 : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
     return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
 }
 // val in [-4096, 4095], steps of 8, round to nearest
@@ -436,12 +479,11 @@ __device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
     uint32_t hval;
     // ival.lo = x.hi; ival.hi = y.hi;
     // <=> divide x and y by 65536 and pack them
-    // asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
+    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
     // (val & 0x03FF03FF) ^ 0x72007200
-    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
-    //              : "=r"(hval)
-    //              : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
+                 : "=r"(hval)
+                 : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
     return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
 }
 // val in [-512, 511]
@@ -450,12 +492,11 @@ __device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
     uint32_t hval;
     // ival.lo = x.lo; ival.hi = y.lo;
     // <=> divide x and y by 65536 and pack them
-    // asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
+    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
    // (val & 0x03FF03FF) ^ 0x66006600
-    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
-    //              : "=r"(hval)
-    //              : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
-    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
+                 : "=r"(hval)
+                 : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
     return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
 }
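
All three int2half2_fast variants pair two PTX idioms: prmt.b32 builds one register from selected bytes of x and y (selector 0x5410 picks the low 16 bits of each input, 0x7632 the high 16), and lop3.b32 with immLut (0xF0 & 0xCC) ^ 0xAA evaluates (a & b) ^ c in a single instruction, splicing each 10-bit two's-complement value into a half mantissa beside a magic exponent. A scalar host-side model of one lane of the 512 variant (illustrative names; the half is decoded by hand rather than with fp16 types):

    #include <cstdint>
    #include <cstdio>

    // Scalar model of one lane of int2half2_fast_512: splice a 10-bit
    // two's-complement integer into the mantissa of the half 0x6600 (== 1536.0).
    static float int2half_fast_512_ref(int v) {  // v in [-512, 511]
        uint16_t bits = (uint16_t)((v & 0x3FF) ^ 0x6600);
        // Decode by hand: exponent 2^10, mantissa in integer units for this range.
        int mantissa = bits & 0x3FF;             // equals (v + 512) mod 1024
        float value = 1024.0f + (float)mantissa; // the 0x6600 binade covers [1024, 2048)
        return value - 1536.0f;                  // 1024 + (v + 512) - 1536 = v
    }

    int main() {
        int tests[] = {-512, -1, 0, 1, 511};
        for (int v : tests) printf("%d -> %.1f\n", v, int2half_fast_512_ref(v));
        return 0;
    }
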
...