"server/text_generation_server/models/seq2seq_lm.py" did not exist on "2ad895a6cc530474cae7e24ace1e463018172d0e"
Commit 8be63f64 authored by fengzch's avatar fengzch
Browse files

PTX instruction replacement
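This change works in both directions: NVIDIA PTX that had been stubbed out for HIP is restored (ldmatrix, st.shared, predicated st.global, red.relaxed, prmt/lop3, the e2m1x2 FP4 convert), while instructions with straightforward scalar equivalents (the bf16x2 dequantization FMA, predicated loads, cvt.rni-based quantization, int2float_fast) are replaced with portable C++.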

parent d21ab0f5
@@ -81,6 +81,53 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
// Device-side bfloat16-to-float conversion
__device__ float bf16_to_float_device(uint16_t bf16) {
// Convert bfloat16 to float: the bf16 bits, shifted left by 16, become the high 16 bits of the float
uint32_t val = (uint32_t)bf16 << 16;
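// e.g. bf16 0x3F80 (1.0) -> 0x3F800000 -> 1.0f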
return __uint_as_float(val);
}
// Device-side float-to-bfloat16 conversion
__device__ uint16_t float_to_bf16_device(float f) {
// Convert float to bfloat16: keep the high 16 bits of the float
uint32_t float_bits = __float_as_uint(f);
// Round to nearest even
uint32_t rounding_bias = ((float_bits >> 16) & 1) + 0x7FFF;
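// e.g. f = 1.0f (0x3F800000): low bit of the kept half is 0, so bias = 0x7FFF
// and (0x3F800000 + 0x7FFF) >> 16 = 0x3F80, the bf16 encoding of 1.0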
return (uint16_t)((float_bits + rounding_bias) >> 16);
}
// C++ implementation of a bfloat16x2 FMA
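// Note: a scalar stand-in for the PTX fma.rn.bf16x2 instruction. The product
// and sum are evaluated in float and then rounded once to bf16, so in rare
// double-rounding cases the result can differ from the hardware instruction
// in the last ULP.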
__device__ uint32_t fma_bf16x2_cpp(uint32_t a, uint32_t b, uint32_t c) {
// Unpack the high and low 16-bit lanes of a, b, and c
uint16_t a_high = (uint16_t)(a >> 16);
uint16_t a_low = (uint16_t)(a & 0xFFFF);
uint16_t b_high = (uint16_t)(b >> 16);
uint16_t b_low = (uint16_t)(b & 0xFFFF);
uint16_t c_high = (uint16_t)(c >> 16);
uint16_t c_low = (uint16_t)(c & 0xFFFF);
// Convert each bf16 lane to float for the arithmetic
// High lane: (a_high * b_high) + c_high
float a_high_f = bf16_to_float_device(a_high);
float b_high_f = bf16_to_float_device(b_high);
float c_high_f = bf16_to_float_device(c_high);
float result_high_f = a_high_f * b_high_f + c_high_f;
uint16_t result_high = float_to_bf16_device(result_high_f);
// Low lane: (a_low * b_low) + c_low
float a_low_f = bf16_to_float_device(a_low);
float b_low_f = bf16_to_float_device(b_low);
float c_low_f = bf16_to_float_device(c_low);
float result_low_f = a_low_f * b_low_f + c_low_f;
uint16_t result_low = float_to_bf16_device(result_low_f);
// Repack the two bf16 results
return ((uint32_t)result_high << 16) | result_low;
}
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
// dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
@@ -103,22 +150,27 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &so
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
// Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 ((i4s >> 4) & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s >> 4), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 ((i4s >> 8) & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(i4s >> 8), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 ((i4s >> 12) & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(i4s >> 12), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[0])
// : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_23 ((i4s >> 4) & 0x000f000f) | 0x43004300
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[1])
// : "r"(i4s >> 4), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_45 ((i4s >> 8) & 0x000f000f) | 0x43004300
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[2])
// : "r"(i4s >> 8), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_67 ((i4s >> 12) & 0x000f000f) | 0x43004300
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[3])
// : "r"(i4s >> 12), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
h[0] = ((i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM);
h[1] = (((i4s >> 4) & MASK) | I4s_TO_BF16s_MAGIC_NUM);
h[2] = (((i4s >> 8) & MASK) | I4s_TO_BF16s_MAGIC_NUM);
h[3] = (((i4s >> 12) & MASK) | I4s_TO_BF16s_MAGIC_NUM);
// static constexpr uint32_t BF16_BIAS = 0xC308C308;
// This is the BF16 {-128, -128} represented as an integer, we do not need to map to [-8, 7]
@@ -134,9 +186,8 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &so
// asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
// // Convert elt_67
// asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
h[0] = __hsub(h[0], __float2bfloat16_rn(128.0f));
h[1] = __hsub(h[1], __float2bfloat16_rn(128.0f));
h[2] = __hsub(h[2], __float2bfloat16_rn(128.0f));
h[3] = __hsub(h[3], __float2bfloat16_rn(128.0f));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
h[0] = fma_bf16x2_cpp(h[0], BF16_ONE, BF16_BIAS);
h[1] = fma_bf16x2_cpp(h[1], BF16_ONE, BF16_BIAS);
h[2] = fma_bf16x2_cpp(h[2], BF16_ONE, BF16_BIAS);
h[3] = fma_bf16x2_cpp(h[3], BF16_ONE, BF16_BIAS);
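// The magic-number trick: OR-ing a nibble n into 0x4300 (bf16 128.0) yields
// the bf16 value 128 + n, so a single fma with BF16_ONE and a negative bias
// (both defined above, outside this hunk) shifts the packed lanes back to the
// dequantized range.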
}
@@ -92,28 +92,26 @@ template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
// asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
// : "r"(addr));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
}
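// ldmatrix.sync.aligned.m8n8.x4.shared.b16, used above and in the trans
// variant below, collectively loads four 8x8 tiles of 16-bit elements from
// shared memory; each thread of the warp receives its fragment as four 32-bit
// registers, which is why the results are written through an unsigned* view
// of shared_warp.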
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
// asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
// "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
// : "r"(addr));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
"=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
: "r"(addr));
}
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) {
@@ -383,10 +381,10 @@ __global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
int M,
int N,
int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
#endif
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// trap_unsupported_arch();
// return;
// #endif
using f162_t = typename packed_as<f16_t, 2>::type;
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
......
@@ -46,37 +46,73 @@ __device__ __forceinline__ static T load(const T *addr) {
return *addr;
}
// template<typename T>
// __device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
// if constexpr (sizeof(T) == 4) {
// uint32_t data;
// // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
// // "@loadpred ld.global.nc.b32 %0, [%1];"
// // "}"
// // : "=r"(data)
// // : "l"(addr), "r"((int)pred));
// return *reinterpret_cast<T *>(&data);
// }
// if constexpr (sizeof(T) == 8) {
// uint2 data;
// // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
// // "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
// // "}"
// // : "=r"(data.x), "=r"(data.y)
// // : "l"(addr), "r"((int)pred));
// return *reinterpret_cast<T *>(&data);
// }
// if constexpr (sizeof(T) == 16) {
// uint4 data;
// // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
// // "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
// // "}"
// // : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
// // : "l"(addr), "r"((int)pred));
// return *reinterpret_cast<T *>(&data);
// }
// T result;
// if (pred) {
// result = *addr;
// }
// return result;
// }
template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data;
// asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
// "@loadpred ld.global.nc.b32 %0, [%1];"
// "}"
// : "=r"(data)
// : "l"(addr), "r"((int)pred));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
if (pred) {
const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
unsigned char *dst = reinterpret_cast<unsigned char *>(&data);
#pragma unroll
for (int i = 0; i < 4; ++i) dst[i] = src[i];
}
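// As with the predicated PTX load, `data` stays unwritten when pred is false,
// so the returned value is indeterminate in that case; the same holds for the
// 8- and 16-byte branches below.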
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 8) {
uint2 data;
// asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
// "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
// "}"
// : "=r"(data.x), "=r"(data.y)
// : "l"(addr), "r"((int)pred));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
if (pred) {
const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
unsigned char *dst = reinterpret_cast<unsigned char *>(&data);
#pragma unroll
for (int i = 0; i < 8; ++i) dst[i] = src[i];
}
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 16) {
uint4 data;
// asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
// "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
// "}"
// : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
// : "l"(addr), "r"((int)pred));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
if (pred) {
const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
unsigned char *dst = reinterpret_cast<unsigned char *>(&data);
#pragma unroll
for (int i = 0; i < 16; ++i) dst[i] = src[i];
}
return *reinterpret_cast<T *>(&data);
}
@@ -92,21 +128,17 @@ __device__ __forceinline__ static void store(T *addr, T val) {
if constexpr (shmem) {
if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val);
// asm volatile(
// "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile(
"st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
return;
}
if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val);
// asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
// "r"(data.x),
// "r"(data.y),
// "r"(data.z),
// "r"(data.w));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
"r"(data.x),
"r"(data.y),
"r"(data.z),
"r"(data.w));
return;
}
*addr = val;
@@ -115,17 +147,17 @@ __device__ __forceinline__ static void store(T *addr, T val) {
if constexpr (sizeof(T) == 4) {
// __stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
*reinterpret_cast<unsigned int *>(addr) = *reinterpret_cast<unsigned int *>(&val);
return;
}
if constexpr (sizeof(T) == 8) {
// __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
*reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
return;
}
if constexpr (sizeof(T) == 16) {
// __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
*reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
return;
}
*addr = val;
@@ -135,39 +167,33 @@ template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data = *reinterpret_cast<uint32_t *>(&val);
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred st.global.cg.b32 [%1], %2;"
// "}" ::"r"((int)pred),
// "l"(addr),
// "r"(data));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.b32 [%1], %2;"
"}" ::"r"((int)pred),
"l"(addr),
"r"(data));
return;
}
if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val);
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
// "}" ::"r"((int)pred),
// "l"(addr),
// "r"(data.x),
// "r"(data.y));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
"}" ::"r"((int)pred),
"l"(addr),
"r"(data.x),
"r"(data.y));
return;
}
if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val);
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
// "}" ::"r"((int)pred),
// "l"(addr),
// "r"(data.x),
// "r"(data.y),
// "r"(data.z),
// "r"(data.w));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
"}" ::"r"((int)pred),
"l"(addr),
"r"(data.x),
"r"(data.y),
"r"(data.z),
"r"(data.w));
return;
}
@@ -229,11 +255,17 @@ template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
int v1, v2;
uint32_t result;
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
// asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
// asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
// asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
v1 = __float2int_rn(value.x);
v2 = __float2int_rn(value.y);
int s1 = max(-8, min(7, v1));
int s2 = max(-8, min(7, v2));
unsigned int u1 = s1 & 0xF;
unsigned int u2 = s2 & 0xF;
result = (u2 << 4) | u1;
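// e.g. value = {3.6f, -2.2f}: v1 = 4, v2 = -2 -> u1 = 0x4, u2 = 0xE -> result = 0xE4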
return result;
}
@@ -241,11 +273,15 @@ template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
int v1, v2;
uint32_t result;
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
// asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
// asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
// asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
v1 = __float2int_rn(value.x);
v2 = __float2int_rn(value.y);
unsigned int u1 = static_cast<unsigned int>(max(0, min(15, v1)));
unsigned int u2 = static_cast<unsigned int>(max(0, min(15, v2)));
result = (u2 << 4) | u1;
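// e.g. value = {3.6f, 17.0f}: v1 = 4, v2 = 17 saturates to 15 -> result = 0xF4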
return result;
}
@@ -253,21 +289,29 @@ template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
int v1, v2;
uint32_t result;
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
// asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
// asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
// asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
v1 = __float2int_rn(value.x); // round to nearest, matching cvt.rni
v2 = __float2int_rn(value.y);
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
// Saturate to the signed 8-bit range [-128, 127]
int s1 = max(-128, min(127, v1));
int s2 = max(-128, min(127, v2));
// Convert the signed values to unsigned bit patterns:
// masking to the low 8 bits yields the two's-complement representation
unsigned int u1 = s1 & 0xFF; // keep only the low 8 bits
unsigned int u2 = s2 & 0xFF;
result = (u2 << 8) | u1;
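// e.g. value = {100.4f, -200.0f}: v1 = 100, v2 saturates to -128 -> result = (0x80 << 8) | 0x64 = 0x8064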
return result;
}
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
uint32_t result;
// asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
// : "=r"(result)
// : "f"(value.y), "f"(value.x));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
: "=r"(result)
: "f"(value.y), "f"(value.x));
return result;
}
@@ -372,17 +416,15 @@ __device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bf
};
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
// asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
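// red.relaxed.gpu.global.add.f32 is a one-way atomic add (no return value)
// with relaxed ordering at GPU scope; reduce_add_pred below predicates the
// same instruction.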
asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
}
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
// "}" ::"r"((int)pred),
// "l"(addr),
// "f"(val));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
"}" ::"r"((int)pred),
"l"(addr),
"f"(val));
}
template<int cnt, typename F>
@@ -394,13 +436,15 @@ __device__ __forceinline__ static void unrolled_loop(F &&lambda) {
// int2float is slow on sm_80 and before
// val in [-4194304, 4194303]
__device__ __forceinline__ static float int2float_fast(int val) {
float fval;
// fval = (val & 0x7FFFFF) ^ 0x4B400000
// float fval;
// // fval = (val & 0x7FFFFF) ^ 0x4B400000
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
// : "=f"(fval)
// : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
return fval - 12582912.0f;
unsigned int temp = (val & 0x7FFFFF) ^ 0x4B400000;
float result;
memcpy(&result, &temp, sizeof(float));
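// The XOR biases val by 2^22 inside the mantissa of 1.5 * 2^23, so the float
// equals 12582912.0f + val. e.g. val = 5: temp = 0x4B400005 = 12582917.0f;
// val = -1: temp = 0x4B3FFFFF = 12582911.0f.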
return result - 12582912.0f;
}
template<typename To, typename From>
@@ -416,13 +460,12 @@ __device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
uint32_t ival;
uint32_t hval;
// ival.lo = x.lo; ival.hi = y.lo;
// asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
ival = ival >> 4;
// (val & 0x03FF03FF) ^ 0x76007600
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
// : "=r"(hval)
// : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=r"(hval)
: "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}
// val in [-4096, 4095], steps of 8, round to nearest
......@@ -436,12 +479,11 @@ __device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
uint32_t hval;
// ival.lo = x.hi; ival.hi = y.hi;
// <=> divide x and y by 65536 and pack them
// asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
// (val & 0x03FF03FF) ^ 0x72007200
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
// : "=r"(hval)
// : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=r"(hval)
: "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}
// val in [-512, 511]
......@@ -450,12 +492,11 @@ __device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
uint32_t hval;
// ival.lo = x.lo; ival.hi = y.lo;
// <=> divide x and y by 65536 and pack them
//asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
// (val & 0x03FF03FF) ^ 0x66006600
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
// : "=r"(hval)
// : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
// printf("%s-%s-%d: asm not supported in HIP yet!\n", __FILE__, __func__, __LINE__);
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=r"(hval)
: "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}
......