Commit fcc551cb authored by sxtyzhangzk, committed by Zhekai Zhang
Browse files

[major] fix compilation error

WHO THE HELL INVENTED ADL?
parent 9c92fe81
......@@ -126,12 +126,12 @@ public:
for (int i = 0; i < 4; i++) {
results[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1]));
}
return bit_cast<packed_fpsum_t>(results);
return kernels::bit_cast<packed_fpsum_t>(results);
}
__device__ __forceinline__
static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
auto arr = bit_cast<std::array<half2_t, 4>>(input);
auto arr = kernels::bit_cast<std::array<half2_t, 4>>(input);
packed_f32psum_t results;
for (int i = 0; i < 4; i++) {
float2 tmp = half22float2(arr[i]);
......@@ -214,10 +214,10 @@ public:
// Runs the fix_nan that accepts a half2_t on each of the x, y, z and w fields of
// a packed_fpsum_t: every field is reinterpreted as half2_t, passed through that
// fix_nan, reinterpreted back to int and stored in place. Returns the updated value.
//
// NOTE(review): this span is a rendered diff hunk without +/- markers. The four
// unqualified bit_cast lines appear to be the pre-commit text and the four
// kernels::bit_cast lines their replacement — confirm against the repository
// before treating both groups as live code.
__device__ __forceinline__
static packed_fpsum_t fix_nan(packed_fpsum_t input) {
// Unqualified bit_cast calls (presumably the removed side of the diff).
input.x = bit_cast<int>(fix_nan(bit_cast<half2_t>(input.x)));
input.y = bit_cast<int>(fix_nan(bit_cast<half2_t>(input.y)));
input.z = bit_cast<int>(fix_nan(bit_cast<half2_t>(input.z)));
input.w = bit_cast<int>(fix_nan(bit_cast<half2_t>(input.w)));
// Namespace-qualified calls: naming kernels::bit_cast explicitly suppresses
// argument-dependent lookup, which the commit message blames for the build error.
input.x = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.x)));
input.y = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.y)));
input.z = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.z)));
input.w = kernels::bit_cast<int>(fix_nan(kernels::bit_cast<half2_t>(input.w)));
return input;
}
......
......@@ -206,21 +206,21 @@ public:
static constexpr bool is_bf16 = std::is_same_v<half_t, __nv_bfloat16>;
uint4 out1 = mma_m16n8k16_f32f16f16f32<is_bf16>(
bit_cast<uint4>(a),
bit_cast<uint2>(std::array<half2_t, 2>(b.data[0], b.data[1])),
bit_cast<uint4>(float4(psum.data[0], psum.data[1], psum.data[2], psum.data[3])));
kernels::bit_cast<uint4>(a),
kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[0], b.data[1])),
kernels::bit_cast<uint4>(float4(psum.data[0], psum.data[1], psum.data[2], psum.data[3])));
uint4 out2 = mma_m16n8k16_f32f16f16f32<is_bf16>(
bit_cast<uint4>(a),
bit_cast<uint2>(std::array<half2_t, 2>(b.data[2], b.data[3])),
bit_cast<uint4>(float4(psum.data[4], psum.data[5], psum.data[6], psum.data[7])));
psum.data[0] = bit_cast<float>(out1.x);
psum.data[1] = bit_cast<float>(out1.y);
psum.data[2] = bit_cast<float>(out1.z);
psum.data[3] = bit_cast<float>(out1.w);
psum.data[4] = bit_cast<float>(out2.x);
psum.data[5] = bit_cast<float>(out2.y);
psum.data[6] = bit_cast<float>(out2.z);
psum.data[7] = bit_cast<float>(out2.w);
kernels::bit_cast<uint4>(a),
kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[2], b.data[3])),
kernels::bit_cast<uint4>(float4(psum.data[4], psum.data[5], psum.data[6], psum.data[7])));
psum.data[0] = kernels::bit_cast<float>(out1.x);
psum.data[1] = kernels::bit_cast<float>(out1.y);
psum.data[2] = kernels::bit_cast<float>(out1.z);
psum.data[3] = kernels::bit_cast<float>(out1.w);
psum.data[4] = kernels::bit_cast<float>(out2.x);
psum.data[5] = kernels::bit_cast<float>(out2.y);
psum.data[6] = kernels::bit_cast<float>(out2.z);
psum.data[7] = kernels::bit_cast<float>(out2.w);
return psum;
}
......
......@@ -573,7 +573,7 @@ static half2 int2half2_fast_8192(int x, int y) {
ival = ival >> 4;
// (val & 0x03FF03FF) ^ 0x76007600
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}
// val in [-4096, 4095], steps of 8, round to nearest
__device__ __forceinline__
......@@ -590,7 +590,7 @@ static half2 int2half2_fast_4096_rn(int x, int y) {
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
// (val & 0x03FF03FF) ^ 0x72007200
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}
// val in [-512, 511]
__device__ __forceinline__
......@@ -602,7 +602,7 @@ static half2 int2half2_fast_512(int x, int y) {
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
// (val & 0x03FF03FF) ^ 0x66006600
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -1674,7 +1674,7 @@ public:
const int col = n * INSN_N + laneId / 16 * 8; // lane 0-15: n*16+0, lane 16-31: n*16+8
uint4 tmp;
ldmatrix(shmem + col, tmp);
return bit_cast<packed_fpsum_t>(tmp);
return kernels::bit_cast<packed_fpsum_t>(tmp);
}
__device__ __forceinline__
......@@ -1813,30 +1813,30 @@ public:
// Builds a packed_qkv_t from a packed_fpsum_t: input.data[0..3] are each passed
// through convert_half2, reinterpreted as int, and stored in output.x, .y, .z, .w
// in that same order.
//
// NOTE(review): this span is a rendered diff hunk without +/- markers. The
// unqualified bit_cast lines appear to be the pre-commit text and the
// kernels::bit_cast lines their replacement — confirm before treating both
// groups as live code.
__device__ __forceinline__
static packed_qkv_t pack_q(packed_fpsum_t input) {
packed_qkv_t output;
// Unqualified bit_cast calls (presumably the removed side of the diff).
output.x = bit_cast<int>(convert_half2(input.data[0]));
output.y = bit_cast<int>(convert_half2(input.data[1]));
output.z = bit_cast<int>(convert_half2(input.data[2]));
output.w = bit_cast<int>(convert_half2(input.data[3]));
// Namespace-qualified calls; same element order as above.
output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
output.y = kernels::bit_cast<int>(convert_half2(input.data[1]));
output.z = kernels::bit_cast<int>(convert_half2(input.data[2]));
output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
return output;
}
// Builds a packed_qkv_t from a packed_fpsum_t like pack_q, except that the two
// middle elements are swapped: output.y comes from input.data[2] and output.z
// from input.data[1]. Each element goes through convert_half2 and is then
// reinterpreted as int.
//
// NOTE(review): this span is a rendered diff hunk without +/- markers. The
// unqualified bit_cast lines appear to be the pre-commit text and the
// kernels::bit_cast lines their replacement — confirm before treating both
// groups as live code.
__device__ __forceinline__
static packed_qkv_t pack_k(packed_fpsum_t input) {
packed_qkv_t output;
// Unqualified bit_cast calls (presumably the removed side of the diff).
output.x = bit_cast<int>(convert_half2(input.data[0]));
output.y = bit_cast<int>(convert_half2(input.data[2]));
output.z = bit_cast<int>(convert_half2(input.data[1]));
output.w = bit_cast<int>(convert_half2(input.data[3]));
// Namespace-qualified calls; note the data[2] / data[1] swap for .y and .z.
output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
output.y = kernels::bit_cast<int>(convert_half2(input.data[2]));
output.z = kernels::bit_cast<int>(convert_half2(input.data[1]));
output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
return output;
}
// Builds a packed_qkv_t from a packed_fpsum_t: input.data[0..3] are each passed
// through movmatrix, then convert_half2, then reinterpreted as int and stored in
// output.x, .y, .z, .w in that same order.
// NOTE(review): movmatrix is presumably a wrapper over the PTX movmatrix
// instruction — its definition is not visible here, confirm.
//
// NOTE(review): this span is a rendered diff hunk without +/- markers. The
// unqualified bit_cast lines appear to be the pre-commit text and the
// kernels::bit_cast lines their replacement — confirm before treating both
// groups as live code.
__device__ __forceinline__
static packed_qkv_t pack_v(packed_fpsum_t input) {
packed_qkv_t output;
// Unqualified bit_cast calls (presumably the removed side of the diff).
output.x = bit_cast<int>(convert_half2(movmatrix(input.data[0])));
output.y = bit_cast<int>(convert_half2(movmatrix(input.data[1])));
output.z = bit_cast<int>(convert_half2(movmatrix(input.data[2])));
output.w = bit_cast<int>(convert_half2(movmatrix(input.data[3])));
// Namespace-qualified calls; same element order as above.
output.x = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[0])));
output.y = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[1])));
output.z = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[2])));
output.w = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[3])));
return output;
}
......@@ -1867,7 +1867,7 @@ public:
unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE {
unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE {
packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]);
mask(pack, bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M);
mask(pack, kernels::bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M);
store(&ptrlane[(m * WARP_N_TILES + n) * WARP_SIZE], pack);
});
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment