Commit fcc551cb authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

[major] fix compilation error

WHO THE HELL INVENTED ADL?
parent 9c92fe81
...@@ -126,12 +126,12 @@ public: ...@@ -126,12 +126,12 @@ public:
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
results[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1])); results[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1]));
} }
return bit_cast<packed_fpsum_t>(results); return kernels::bit_cast<packed_fpsum_t>(results);
} }
__device__ __forceinline__ __device__ __forceinline__
static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) { static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
auto arr = bit_cast<std::array<half2_t, 4>>(input); auto arr = kernels::bit_cast<std::array<half2_t, 4>>(input);
packed_f32psum_t results; packed_f32psum_t results;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
float2 tmp = half22float2(arr[i]); float2 tmp = half22float2(arr[i]);
...@@ -214,10 +214,10 @@ public: ...@@ -214,10 +214,10 @@ public:
__device__ __forceinline__ __device__ __forceinline__
static packed_fpsum_t fix_nan(packed_fpsum_t input) { static packed_fpsum_t fix_nan(packed_fpsum_t input) {
input.x = bit_cast<int>(fix_nan(bit_cast<half2_t>(input.x))); input.x = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.x)));
input.y = bit_cast<int>(fix_nan(bit_cast<half2_t>(input.y))); input.y = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.y)));
input.z = bit_cast<int>(fix_nan(bit_cast<half2_t>(input.z))); input.z = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.z)));
input.w = bit_cast<int>(fix_nan(bit_cast<half2_t>(input.w))); input.w = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.w)));
return input; return input;
} }
......
...@@ -206,21 +206,21 @@ public: ...@@ -206,21 +206,21 @@ public:
static constexpr bool is_bf16 = std::is_same_v<half_t, __nv_bfloat16>; static constexpr bool is_bf16 = std::is_same_v<half_t, __nv_bfloat16>;
uint4 out1 = mma_m16n8k16_f32f16f16f32<is_bf16>( uint4 out1 = mma_m16n8k16_f32f16f16f32<is_bf16>(
bit_cast<uint4>(a), kernels::bit_cast<uint4>(a),
bit_cast<uint2>(std::array<half2_t, 2>(b.data[0], b.data[1])), kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[0], b.data[1])),
bit_cast<uint4>(float4(psum.data[0], psum.data[1], psum.data[2], psum.data[3]))); kernels::bit_cast<uint4>(float4(psum.data[0], psum.data[1], psum.data[2], psum.data[3])));
uint4 out2 = mma_m16n8k16_f32f16f16f32<is_bf16>( uint4 out2 = mma_m16n8k16_f32f16f16f32<is_bf16>(
bit_cast<uint4>(a), kernels::bit_cast<uint4>(a),
bit_cast<uint2>(std::array<half2_t, 2>(b.data[2], b.data[3])), kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[2], b.data[3])),
bit_cast<uint4>(float4(psum.data[4], psum.data[5], psum.data[6], psum.data[7]))); kernels::bit_cast<uint4>(float4(psum.data[4], psum.data[5], psum.data[6], psum.data[7])));
psum.data[0] = bit_cast<float>(out1.x); psum.data[0] = kernels::bit_cast<float>(out1.x);
psum.data[1] = bit_cast<float>(out1.y); psum.data[1] = kernels::bit_cast<float>(out1.y);
psum.data[2] = bit_cast<float>(out1.z); psum.data[2] = kernels::bit_cast<float>(out1.z);
psum.data[3] = bit_cast<float>(out1.w); psum.data[3] = kernels::bit_cast<float>(out1.w);
psum.data[4] = bit_cast<float>(out2.x); psum.data[4] = kernels::bit_cast<float>(out2.x);
psum.data[5] = bit_cast<float>(out2.y); psum.data[5] = kernels::bit_cast<float>(out2.y);
psum.data[6] = bit_cast<float>(out2.z); psum.data[6] = kernels::bit_cast<float>(out2.z);
psum.data[7] = bit_cast<float>(out2.w); psum.data[7] = kernels::bit_cast<float>(out2.w);
return psum; return psum;
} }
......
...@@ -573,7 +573,7 @@ static half2 int2half2_fast_8192(int x, int y) { ...@@ -573,7 +573,7 @@ static half2 int2half2_fast_8192(int x, int y) {
ival = ival >> 4; ival = ival >> 4;
// (val & 0x03FF03FF) ^ 0x76007600 // (val & 0x03FF03FF) ^ 0x76007600
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA)); asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(bit_cast<half2>(hval), half2(-24576.0f, -24576.0f)); return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
} }
// val in [-4096, 4095], steps of 8, round to nearest // val in [-4096, 4095], steps of 8, round to nearest
__device__ __forceinline__ __device__ __forceinline__
...@@ -590,7 +590,7 @@ static half2 int2half2_fast_4096_rn(int x, int y) { ...@@ -590,7 +590,7 @@ static half2 int2half2_fast_4096_rn(int x, int y) {
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632)); asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
// (val & 0x03FF03FF) ^ 0x72007200 // (val & 0x03FF03FF) ^ 0x72007200
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA)); asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(bit_cast<half2>(hval), half2(-12288.0f, -12288.0f)); return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
} }
// val in [-512, 511] // val in [-512, 511]
__device__ __forceinline__ __device__ __forceinline__
...@@ -602,7 +602,7 @@ static half2 int2half2_fast_512(int x, int y) { ...@@ -602,7 +602,7 @@ static half2 int2half2_fast_512(int x, int y) {
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410)); asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
// (val & 0x03FF03FF) ^ 0x66006600 // (val & 0x03FF03FF) ^ 0x66006600
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA)); asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(bit_cast<half2>(hval), half2(-1536.0f, -1536.0f)); return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
} }
}; // namespace nunchaku::kernels }; // namespace nunchaku::kernels
\ No newline at end of file
...@@ -1674,7 +1674,7 @@ public: ...@@ -1674,7 +1674,7 @@ public:
const int col = n * INSN_N + laneId / 16 * 8; // lane 0-15: n*16+0, lane 16-31: n*16+8 const int col = n * INSN_N + laneId / 16 * 8; // lane 0-15: n*16+0, lane 16-31: n*16+8
uint4 tmp; uint4 tmp;
ldmatrix(shmem + col, tmp); ldmatrix(shmem + col, tmp);
return bit_cast<packed_fpsum_t>(tmp); return kernels::bit_cast<packed_fpsum_t>(tmp);
} }
__device__ __forceinline__ __device__ __forceinline__
...@@ -1813,30 +1813,30 @@ public: ...@@ -1813,30 +1813,30 @@ public:
__device__ __forceinline__ __device__ __forceinline__
static packed_qkv_t pack_q(packed_fpsum_t input) { static packed_qkv_t pack_q(packed_fpsum_t input) {
packed_qkv_t output; packed_qkv_t output;
output.x = bit_cast<int>(convert_half2(input.data[0])); output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
output.y = bit_cast<int>(convert_half2(input.data[1])); output.y = kernels::bit_cast<int>(convert_half2(input.data[1]));
output.z = bit_cast<int>(convert_half2(input.data[2])); output.z = kernels::bit_cast<int>(convert_half2(input.data[2]));
output.w = bit_cast<int>(convert_half2(input.data[3])); output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
return output; return output;
} }
__device__ __forceinline__ __device__ __forceinline__
static packed_qkv_t pack_k(packed_fpsum_t input) { static packed_qkv_t pack_k(packed_fpsum_t input) {
packed_qkv_t output; packed_qkv_t output;
output.x = bit_cast<int>(convert_half2(input.data[0])); output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
output.y = bit_cast<int>(convert_half2(input.data[2])); output.y = kernels::bit_cast<int>(convert_half2(input.data[2]));
output.z = bit_cast<int>(convert_half2(input.data[1])); output.z = kernels::bit_cast<int>(convert_half2(input.data[1]));
output.w = bit_cast<int>(convert_half2(input.data[3])); output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
return output; return output;
} }
__device__ __forceinline__ __device__ __forceinline__
static packed_qkv_t pack_v(packed_fpsum_t input) { static packed_qkv_t pack_v(packed_fpsum_t input) {
packed_qkv_t output; packed_qkv_t output;
output.x = bit_cast<int>(convert_half2(movmatrix(input.data[0]))); output.x = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[0])));
output.y = bit_cast<int>(convert_half2(movmatrix(input.data[1]))); output.y = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[1])));
output.z = bit_cast<int>(convert_half2(movmatrix(input.data[2]))); output.z = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[2])));
output.w = bit_cast<int>(convert_half2(movmatrix(input.data[3]))); output.w = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[3])));
return output; return output;
} }
...@@ -1867,7 +1867,7 @@ public: ...@@ -1867,7 +1867,7 @@ public:
unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE { unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE {
unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE { unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE {
packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]); packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]);
mask(pack, bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M); mask(pack, kernels::bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M);
store(&ptrlane[(m * WARP_N_TILES + n) * WARP_SIZE], pack); store(&ptrlane[(m * WARP_N_TILES + n) * WARP_SIZE], pack);
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment