Commit 181f4e43 authored by fengzch

fix compile error

parent 4cdcd76f
@@ -106,12 +106,12 @@ __global__ void gemv_kernel(const half_t *inputs,
                             const int IC,
                             const int OC) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-    if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
-        trap_unsupported_arch();
-        return;
-    }
-#endif
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//     if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
+//         trap_unsupported_arch();
+//         return;
+//     }
+// #endif
     using half2_t = typename packed_as<half_t, 2>::type;
     using accum_t = float;
     using accum2_t = typename packed_as<accum_t, 2>::type;
...
@@ -608,7 +608,16 @@ public:
             packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 +
                                      LITELA_HEAD_DIM / 16 + tile_v];
             for (int j = 0; j < 4; j++) {
-                k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
+                __hip_bfloat162 first;
+                first.x = float(k.data[j].x);
+                first.y = float(k.data[j].y);
+                auto temp = half2_t(0, 0);
+                __hip_bfloat162 sec;
+                sec.x = float(temp.x);
+                sec.y = float(temp.y);
+                auto relu_result = __hmax2(first, sec); // relu
+                k.data[j].x = float(relu_result.x);
+                k.data[j].y = float(relu_result.y);
             }
             attn_sum = mma_litela(k, v, attn_sum);
         }
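Note on this hunk (and the identical one below): the ReLU is now done through an explicit round-trip via `__hip_bfloat162`, since `__hmax2` is not applied to the packed `half2_t` directly in this port. A minimal sketch of the same step factored into a helper — the name `relu2` and the template are illustrative and not part of this commit; it assumes the `<hip/amd_detail/amd_hip_bf16.h>` include added elsewhere in this commit and the same per-component float conversions the diff itself relies on:

```cpp
template<typename half2_t>
__device__ __forceinline__ half2_t relu2(half2_t v) {
    __hip_bfloat162 a, zero;
    a.x    = float(v.x);   // widen each lane to float, then narrow to __hip_bfloat16
    a.y    = float(v.y);
    zero.x = 0.0f;
    zero.y = 0.0f;
    const __hip_bfloat162 r = __hmax2(a, zero);  // per-lane max(v, 0) == ReLU
    v.x = float(r.x);
    v.y = float(r.y);
    return v;
}
// usage inside the loops above: k.data[j] = relu2(k.data[j]);
```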
@@ -632,7 +641,16 @@ public:
             packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
             packed_fpsum_t v = {};
             for (int j = 0; j < 4; j++) {
-                k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
+                __hip_bfloat162 first;
+                first.x = float(k.data[j].x);
+                first.y = float(k.data[j].y);
+                auto temp = half2_t(0, 0);
+                __hip_bfloat162 sec;
+                sec.x = float(temp.x);
+                sec.y = float(temp.y);
+                auto relu_result = __hmax2(first, sec); // relu
+                k.data[j].x = float(relu_result.x);
+                k.data[j].y = float(relu_result.y);
             }
 #pragma unroll
             for (int i = 0; i < 4; i++) {
@@ -801,7 +819,7 @@ public:
         fpsum_warp fpsum;
-        Base::template load_act_to_fpsum<false>()(args.input + m_offset * args.actualN + n_offset,
+        typename Base::template load_act_to_fpsum<false>()(args.input + m_offset * args.actualN + n_offset,
                                                   args.actualN,
                                                   args.actualM - m_offset,
                                                   args.actualN - n_offset,
...
 #pragma once
+#include <hip/amd_detail/amd_hip_bf16.h>
 #include "common.h"
@@ -8,7 +9,7 @@
 #include "mma_earlycuda.cuh"
 #pragma nv_diag_suppress 177
+#define __DTK_ARCH__ 1200
 #ifdef _MSC_VER
 #define ALWAYSINLINE [[msvc::forceinline]]
 #else
@@ -208,11 +209,11 @@ public:
         uint4 out1 = mma_m16n8k16_f32f16f16f32<is_bf16>(
             kernels::bit_cast<uint4>(a),
-            kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[0], b.data[1])),
+            kernels::bit_cast<uint2>(std::array<half2_t, 2>{b.data[0], b.data[1]}),
             kernels::bit_cast<uint4>(float4(psum.data[0], psum.data[1], psum.data[2], psum.data[3])));
         uint4 out2 = mma_m16n8k16_f32f16f16f32<is_bf16>(
             kernels::bit_cast<uint4>(a),
-            kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[2], b.data[3])),
+            kernels::bit_cast<uint2>(std::array<half2_t, 2>{b.data[2], b.data[3]}),
             kernels::bit_cast<uint4>(float4(psum.data[4], psum.data[5], psum.data[6], psum.data[7])));
         psum.data[0] = kernels::bit_cast<float>(out1.x);
         psum.data[1] = kernels::bit_cast<float>(out1.y);
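Note: the paren-to-brace change is required because `std::array` is an aggregate and has no constructor taking individual elements, so `std::array<half2_t, 2>(b.data[0], b.data[1])` does not compile on any toolchain, while brace (aggregate) initialization does. A minimal standalone illustration, not taken from the repository:

```cpp
#include <array>

std::array<float, 2> ok{1.0f, 2.0f};     // aggregate (brace) initialization: fine
// std::array<float, 2> bad(1.0f, 2.0f); // error: no matching constructor
```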
@@ -344,14 +345,28 @@ public:
         const int packIdx = k / (WSCALES_PACK_SIZE * WARP_SIZE);
         const int srcLane = 4 * (k / WSCALES_PACK_SIZE) + laneId % 4;
         const int elementIdx = k % WSCALES_PACK_SIZE / 2;
-        return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
+        half2 temp;
+        temp.x = float(block[packIdx].data[elementIdx].x);
+        temp.y = float(block[packIdx].data[elementIdx].y);
+        half2 res = __shfl(temp, srcLane);
+        half2_t result;
+        result.x = (float)res.x;
+        result.y = (float)res.y;
+        return result;
     }

     // get {k}-th and {k+1}-th ascale from the block, k must be multiples of 2, k must be uniform across all lanes
     __device__ __forceinline__ static half2_t broadcast_ascale(ascale_warp block, int k, int laneId) {
         const int packIdx = k / (ASCALES_PACK_SIZE * WARP_SIZE);
         const int srcLane = 8 * (k / ASCALES_PACK_SIZE) + laneId / 4;
         const int elementIdx = k % ASCALES_PACK_SIZE / 2;
-        return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
+        half2 temp;
+        temp.x = float(block[packIdx].data[elementIdx].x);
+        temp.y = float(block[packIdx].data[elementIdx].y);
+        half2 res = __shfl(temp, srcLane);
+        half2_t result;
+        result.x = (float)res.x;
+        result.y = (float)res.y;
+        return result;
     }

     struct i2f_normal {
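Note: both broadcast helpers change because HIP provides `__shfl` (warp-wide, no mask argument) but no `__shfl_sync`, and the shuffle has no overload for the packed `half2_t` used here, so the value is round-tripped through a plain `half2`. A sketch of the same broadcast factored into a helper — the name `shfl_half2` is illustrative, and it assumes the same per-component float conversions the diff relies on:

```cpp
template<typename half2_t>
__device__ __forceinline__ half2_t shfl_half2(half2_t v, int srcLane) {
    half2 tmp;
    tmp.x = float(v.x);               // repack into a type __shfl accepts
    tmp.y = float(v.y);
    half2 res = __shfl(tmp, srcLane); // warp-wide broadcast from srcLane
    half2_t out;
    out.x = (float)res.x;             // repack into the caller's half2_t
    out.y = (float)res.y;
    return out;
}
// usage: return shfl_half2(block[packIdx].data[elementIdx], srcLane);
```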
@@ -897,16 +912,16 @@ constexpr int max_arch() {
 template<typename kernel, typename... T>
 __global__ static void invoke_kernel(T... args) {
-#ifdef __CUDA_ARCH__
-    if constexpr (__CUDA_ARCH__ >= min_arch<kernel>() && __CUDA_ARCH__ <= max_arch<kernel>()) {
-        kernel()(args...);
-    } else {
-        trap_unsupported_arch();
-    }
-#else
+// #ifdef __CUDA_ARCH__
+//     if constexpr (__CUDA_ARCH__ >= min_arch<kernel>() && __CUDA_ARCH__ <= max_arch<kernel>()) {
+//         kernel()(args...);
+//     } else {
+//         trap_unsupported_arch();
+//     }
+// #else
     // ???
     kernel()(args...);
-#endif
+// #endif
 }

 template<typename T>
...
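Note: the per-architecture dispatch in `invoke_kernel` is disabled here because `__CUDA_ARCH__` is not defined when compiling for AMD targets, so the `min_arch()` / `max_arch()` filter cannot be evaluated as written. A hedged sketch of how a comparable guard could be expressed under HIP — this is an assumption, not part of the commit, and the gfx target macros below are only examples of what a build might enable:

```cpp
template<typename kernel, typename... T>
__global__ static void invoke_kernel(T... args) {
#if defined(__HIP_DEVICE_COMPILE__) && !(defined(__gfx90a__) || defined(__gfx942__))
    trap_unsupported_arch();   // device pass compiled for a target this kernel does not support
#else
    kernel()(args...);
#endif
}
```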
@@ -122,8 +122,8 @@ public:
         for (int mask = 2; mask > 0; mask /= 2) {
 #pragma unroll
             for (int i = 0; i < NUM_GROUPS; i++) {
-                maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor(maxvalue[0][i], mask));
-                maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor(maxvalue[1][i], mask));
+                maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor(float(maxvalue[0][i]), mask));
+                maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor(float(maxvalue[1][i]), mask));
             }
         }
         // lane 0,1,2,3 / 4,5,6,7 / ... should have identical maxvalue now
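Note: here (and in the two later hunks that make the same change) the value shuffled across lanes is widened to `float` first, because the `__shfl_xor` used in this port has no overload for the narrow element type, while the `__hmax` comparison stays in the narrow type. A sketch of the butterfly max-reduction this loop performs, factored out — the helper name and the `half_t` parameter are illustrative, not from the commit:

```cpp
template<typename half_t>
__device__ __forceinline__ half_t butterfly_hmax(half_t v) {
#pragma unroll
    for (int mask = 2; mask > 0; mask /= 2) {
        // widen to float for the XOR shuffle, compare in the narrow type
        v = __hmax(v, half_t(__shfl_xor(float(v), mask)));
    }
    return v;   // lanes within each group of 4 now agree on the maximum
}
```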
@@ -197,56 +197,57 @@ public:
                                                           int ida,
                                                           int idb) {
         packed_f32psum_t out;
-        asm volatile(
-            "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
-            "{%0, %1, %2, %3}, "
-            "{%4, %5, %6, %7}, "
-            "{%8, %9}, "
-            "{%10, %11, %12, %13}, "
-            "{%14}, {%15, %16}, "
-            "{%17}, {%18, %19};"
-            : "=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
-            : "r"(act.x),
-              "r"(act.y),
-              "r"(act.z),
-              "r"(act.w),
-              "r"(wgt.x),
-              "r"(wgt.y),
-              "f"(psum.data[0]),
-              "f"(psum.data[1]),
-              "f"(psum.data[2]),
-              "f"(psum.data[3]),
-              "r"(amscale),
-              "n"(0),
-              "h"((short)ida),
-              "r"(wmscale),
-              "n"(0),
-              "h"((short)(idb * 2)));
+        // asm volatile(
+        //     "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
+        //     "{%0, %1, %2, %3}, "
+        //     "{%4, %5, %6, %7}, "
+        //     "{%8, %9}, "
+        //     "{%10, %11, %12, %13}, "
+        //     "{%14}, {%15, %16}, "
+        //     "{%17}, {%18, %19};"
+        //     : "=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
+        //     : "r"(act.x),
+        //       "r"(act.y),
+        //       "r"(act.z),
+        //       "r"(act.w),
+        //       "r"(wgt.x),
+        //       "r"(wgt.y),
+        //       "f"(psum.data[0]),
+        //       "f"(psum.data[1]),
+        //       "f"(psum.data[2]),
+        //       "f"(psum.data[3]),
+        //       "r"(amscale),
+        //       "n"(0),
+        //       "h"((short)ida),
+        //       "r"(wmscale),
+        //       "n"(0),
+        //       "h"((short)(idb * 2)));
-        asm volatile(
-            "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
-            "{%0, %1, %2, %3}, "
-            "{%4, %5, %6, %7}, "
-            "{%8, %9}, "
-            "{%10, %11, %12, %13}, "
-            "{%14}, {%15, %16}, "
-            "{%17}, {%18, %19};"
-            : "=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
-            : "r"(act.x),
-              "r"(act.y),
-              "r"(act.z),
-              "r"(act.w),
-              "r"(wgt.z),
-              "r"(wgt.w),
-              "f"(psum.data[4]),
-              "f"(psum.data[5]),
-              "f"(psum.data[6]),
-              "f"(psum.data[7]),
-              "r"(amscale),
-              "n"(0),
-              "h"((short)ida),
-              "r"(wmscale),
-              "n"(0),
-              "h"((short)(idb * 2 + 1)));
+        // asm volatile(
+        //     "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
+        //     "{%0, %1, %2, %3}, "
+        //     "{%4, %5, %6, %7}, "
+        //     "{%8, %9}, "
+        //     "{%10, %11, %12, %13}, "
+        //     "{%14}, {%15, %16}, "
+        //     "{%17}, {%18, %19};"
+        //     : "=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
+        //     : "r"(act.x),
+        //       "r"(act.y),
+        //       "r"(act.z),
+        //       "r"(act.w),
+        //       "r"(wgt.z),
+        //       "r"(wgt.w),
+        //       "f"(psum.data[4]),
+        //       "f"(psum.data[5]),
+        //       "f"(psum.data[6]),
+        //       "f"(psum.data[7]),
+        //       "r"(amscale),
+        //       "n"(0),
+        //       "h"((short)ida),
+        //       "r"(wmscale),
+        //       "n"(0),
+        //       "h"((short)(idb * 2 + 1)));
+        std::cout << __func__ << "mma_fp4 is not implemented for HIP yet[asm error!!!]" << std::endl;
         return out;
     }
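Note: the `mma.sync...kind::mxf4nvf4.block_scale` instruction is NVIDIA-specific PTX with no HIP counterpart, so the body is stubbed out and only reports at runtime. A hedged alternative for the stub, as an assumption rather than part of this commit: if `mma_fp4` is compiled as device code, `std::cout` is not available there, and a device-side diagnostic would normally use `printf` plus a trap, for example:

```cpp
__device__ __forceinline__ void report_unimplemented_mma_fp4() {
    // printf and abort() are usable from HIP device code; iostream is host-only
    printf("mma_fp4 is not implemented for HIP yet\n");
    abort();
}
```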
@@ -465,11 +466,11 @@ public:
         }
 #pragma unroll
         for (int mask = 2; mask > 0; mask /= 2) {
-            maxvalue[0] = __hmax(maxvalue[0], __shfl_xor(maxvalue[0], mask));
-            maxvalue[1] = __hmax(maxvalue[1], __shfl_xor(maxvalue[1], mask));
+            maxvalue[0] = __hmax(maxvalue[0], __shfl_xor(float(maxvalue[0]), mask));
+            maxvalue[1] = __hmax(maxvalue[1], __shfl_xor(float(maxvalue[1]), mask));
         }
-        maxvalue[0] = __shfl_sync(~0, maxvalue[0], laneId / 4 * 4);
-        maxvalue[1] = __shfl_sync(~0, maxvalue[1], laneId / 4 * 4);
+        maxvalue[0] = __shfl(float(maxvalue[0]), laneId / 4 * 4);
+        maxvalue[1] = __shfl(float(maxvalue[1]), laneId / 4 * 4);
         float scale[2];
         // scale[0] = float(maxvalue[0]) / QVALUE_MAX;
@@ -577,14 +578,14 @@ public:
         for (int mask = NUM_PACKS_PER_ROW / 2; mask > 0; mask /= 2) {
 #pragma unroll
             for (int i = 0; i < NUM_PACKWARPS; i++) {
-                maxvalue[i] = __hmax(maxvalue[i], __shfl_xor(maxvalue[i], mask));
+                maxvalue[i] = __hmax(maxvalue[i], __shfl_xor(float(maxvalue[i]), mask));
             }
         }
         // broadcast (max)
 #pragma unroll
         for (int i = 0; i < NUM_PACKWARPS; i++) {
-            maxvalue[i] = __shfl_sync(~0, maxvalue[i], laneId / NUM_PACKS_PER_ROW * NUM_PACKS_PER_ROW);
+            maxvalue[i] = __shfl(float(maxvalue[i]), laneId / NUM_PACKS_PER_ROW * NUM_PACKS_PER_ROW);
         }
         // quantize
@@ -1150,7 +1151,7 @@ public:
         fpsum_warp fpsum;
-        Base::template load_act_to_fpsum<fuse_glu>()(args.input + m_offset * args.actualN + n_offset,
+        typename Base::template load_act_to_fpsum<fuse_glu>()(args.input + m_offset * args.actualN + n_offset,
                                                      args.actualN,
                                                      args.actualM - m_offset,
                                                      args.actualN - n_offset,
...
@@ -2,7 +2,7 @@
 #include <cstdint>
 #include "common.h"
+#define __DTK_ARCH__ 1200
 // only supports cuda 12.5+
 namespace nunchaku::kernels {
...
@@ -2,7 +2,7 @@
 #include <cstdint>
 #include "common.h"
+#define __DTK_ARCH__ 1200
 // cuda 12.4- does not support "C" constraint in inline assembly :(
 // use explicit specialization for now
@@ -122,14 +122,14 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
     static constexpr int K = 64;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    asm volatile(
-        "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
-        "{%0, %1, %2, %3},"
-        "{%4, %5, %6, %7},"
-        "{%8, %9},"
-        "{%10, %11, %12, %13};\n"
-        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
-        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
+    // asm volatile(
+    //     "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
+    //     "{%0, %1, %2, %3},"
+    //     "{%4, %5, %6, %7},"
+    //     "{%8, %9},"
+    //     "{%10, %11, %12, %13};\n"
+    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
+    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
 #else
     asm volatile("{"
                  ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
@@ -176,14 +176,14 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
     static constexpr int K = 64;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    asm volatile(
-        "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
-        "{%0, %1, %2, %3},"
-        "{%4, %5, %6, %7},"
-        "{%8, %9},"
-        "{%10, %11, %12, %13};\n"
-        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
-        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
+    // asm volatile(
+    //     "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
+    //     "{%0, %1, %2, %3},"
+    //     "{%4, %5, %6, %7},"
+    //     "{%8, %9},"
+    //     "{%10, %11, %12, %13};\n"
+    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
+    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
 #else
     asm volatile("{"
                  ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
...