Commit 181f4e43 authored by fengzch

fix compile error

parent 4cdcd76f
......@@ -106,12 +106,12 @@ __global__ void gemv_kernel(const half_t *inputs,
const int IC,
const int OC) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
trap_unsupported_arch();
return;
}
#endif
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
// trap_unsupported_arch();
// return;
// }
// #endif
using half2_t = typename packed_as<half_t, 2>::type;
using accum_t = float;
using accum2_t = typename packed_as<accum_t, 2>::type;
......
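The block disabled in the hunk above is NVIDIA-specific: it traps when a bf16 gemv is launched on pre-SM80 hardware. A minimal sketch of how the same guard could stay in place for CUDA builds while being preprocessed away on the AMD HIP platform; this is an illustration, not part of this commit, and it assumes __HIP_PLATFORM_AMD__ as the platform switch while reusing the surrounding kernel's half_t and the repo's existing trap_unsupported_arch:

#if !defined(__HIP_PLATFORM_AMD__) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  // NVIDIA-only path: __nv_bfloat16 does not exist under HIP on AMD targets.
  if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
    trap_unsupported_arch();
    return;
  }
#endif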
......@@ -608,7 +608,16 @@ public:
packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 +
LITELA_HEAD_DIM / 16 + tile_v];
for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
__hip_bfloat162 first;
first.x = float(k.data[j].x);
first.y = float(k.data[j].y);
auto temp = half2_t(0, 0);
__hip_bfloat162 sec;
sec.x = float(temp.x);
sec.y = float(temp.y);
auto relu_result = __hmax2(first, sec); // relu
k.data[j].x = float(relu_result.x);
k.data[j].y = float(relu_result.y);
}
attn_sum = mma_litela(k, v, attn_sum);
}
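The per-element float round trip that the hunk writes inline for the bf16 ReLU could also be factored into a single helper. A sketch only, under the same assumptions the diff already makes (half2_t components convert to and from float, and __hmax2 for __hip_bfloat162 is provided by amd_hip_bf16.h); the helper name is illustrative:

template <typename half2_t>
__device__ __forceinline__ half2_t relu_via_bf162(half2_t v) {
  // Widen the pair to __hip_bfloat162 through float, as the hunk does inline.
  __hip_bfloat162 a, zero;
  a.x = __float2bfloat16(float(v.x));
  a.y = __float2bfloat16(float(v.y));
  zero.x = __float2bfloat16(0.0f);
  zero.y = __float2bfloat16(0.0f);
  __hip_bfloat162 r = __hmax2(a, zero); // elementwise max with 0 == ReLU
  half2_t out;
  out.x = float(r.x);
  out.y = float(r.y);
  return out;
}

With such a helper, both loop bodies would reduce to k.data[j] = relu_via_bf162(k.data[j]);.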
......@@ -632,7 +641,16 @@ public:
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = {};
for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
__hip_bfloat162 first;
first.x = float(k.data[j].x);
first.y = float(k.data[j].y);
auto temp = half2_t(0, 0);
__hip_bfloat162 sec;
sec.x = float(temp.x);
sec.y = float(temp.y);
auto relu_result = __hmax2(first, sec); // relu
k.data[j].x = float(relu_result.x);
k.data[j].y = float(relu_result.y);
}
#pragma unroll
for (int i = 0; i < 4; i++) {
......@@ -801,7 +819,7 @@ public:
fpsum_warp fpsum;
Base::template load_act_to_fpsum<false>()(args.input + m_offset * args.actualN + n_offset,
typename Base::template load_act_to_fpsum<false>()(args.input + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
......
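The added `typename ... template` at the two load_act_to_fpsum call sites is standard dependent-name disambiguation: inside a template, a member template of a dependent base must be introduced with both keywords, and clang-based compilers such as hipcc enforce this where nvcc has historically been lenient. A small standalone illustration (all names here are made up for the example):

struct BaseImpl {
  template <bool fuse_glu>
  struct load_act_to_fpsum {
    void operator()(const float *) const {}
  };
};

template <typename Base>
struct EpilogueLike : Base {
  void run(const float *in) {
    // Base is a dependent type, so its member template needs both keywords:
    typename Base::template load_act_to_fpsum<false>()(in);
    // Base::load_act_to_fpsum<false>()(in);  // ill-formed: '<' parsed as less-than
  }
};

// usage: EpilogueLike<BaseImpl>{}.run(nullptr);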
#pragma once
#include <hip/amd_detail/amd_hip_bf16.h>
#include "common.h"
......@@ -8,7 +9,7 @@
#include "mma_earlycuda.cuh"
#pragma nv_diag_suppress 177
#define __DTK_ARCH__ 1200
#ifdef _MSC_VER
#define ALWAYSINLINE [[msvc::forceinline]]
#else
......@@ -208,11 +209,11 @@ public:
uint4 out1 = mma_m16n8k16_f32f16f16f32<is_bf16>(
kernels::bit_cast<uint4>(a),
kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[0], b.data[1])),
kernels::bit_cast<uint2>(std::array<half2_t, 2>{b.data[0], b.data[1]}),
kernels::bit_cast<uint4>(float4(psum.data[0], psum.data[1], psum.data[2], psum.data[3])));
uint4 out2 = mma_m16n8k16_f32f16f16f32<is_bf16>(
kernels::bit_cast<uint4>(a),
kernels::bit_cast<uint2>(std::array<half2_t, 2>(b.data[2], b.data[3])),
kernels::bit_cast<uint2>(std::array<half2_t, 2>{b.data[2], b.data[3]}),
kernels::bit_cast<uint4>(float4(psum.data[4], psum.data[5], psum.data[6], psum.data[7])));
psum.data[0] = kernels::bit_cast<float>(out1.x);
psum.data[1] = kernels::bit_cast<float>(out1.y);
......@@ -344,14 +345,28 @@ public:
const int packIdx = k / (WSCALES_PACK_SIZE * WARP_SIZE);
const int srcLane = 4 * (k / WSCALES_PACK_SIZE) + laneId % 4;
const int elementIdx = k % WSCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
half2 temp;
temp.x = float(block[packIdx].data[elementIdx].x);
temp.y = float(block[packIdx].data[elementIdx].y);
half2 res = __shfl(temp, srcLane);
half2_t result;
result.x = (float)res.x;
result.y = (float)res.y;
return result;
}
// get {k}-th and {k+1}-th ascale from the block, k must be multiples of 2, k must be uniform across all lanes
__device__ __forceinline__ static half2_t broadcast_ascale(ascale_warp block, int k, int laneId) {
const int packIdx = k / (ASCALES_PACK_SIZE * WARP_SIZE);
const int srcLane = 8 * (k / ASCALES_PACK_SIZE) + laneId / 4;
const int elementIdx = k % ASCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
half2 temp;
temp.x = float(block[packIdx].data[elementIdx].x);
temp.y = float(block[packIdx].data[elementIdx].y);
half2 res = __shfl(temp, srcLane);
half2_t result;
result.x = (float)res.x;
result.y = (float)res.y;
return result;
}
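Both broadcast_* helpers above now perform the same half2_t -> half2 -> __shfl -> half2_t round trip. A sketch of that round trip pulled into one utility, assuming the __half2 overload of __shfl from hip_fp16.h that the diff itself relies on (the function name is illustrative):

template <typename half2_t>
__device__ __forceinline__ half2_t shfl_half2_t(half2_t v, int srcLane) {
  half2 tmp;
  tmp.x = float(v.x);               // widen to float, narrow to __half
  tmp.y = float(v.y);
  half2 res = __shfl(tmp, srcLane); // read tmp from srcLane's registers
  half2_t out;
  out.x = float(res.x);
  out.y = float(res.y);
  return out;
}

broadcast_wscale and broadcast_ascale would then both end with return shfl_half2_t(block[packIdx].data[elementIdx], srcLane);.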
struct i2f_normal {
......@@ -897,16 +912,16 @@ constexpr int max_arch() {
template<typename kernel, typename... T>
__global__ static void invoke_kernel(T... args) {
#ifdef __CUDA_ARCH__
if constexpr (__CUDA_ARCH__ >= min_arch<kernel>() && __CUDA_ARCH__ <= max_arch<kernel>()) {
kernel()(args...);
} else {
trap_unsupported_arch();
}
#else
// #ifdef __CUDA_ARCH__
// if constexpr (__CUDA_ARCH__ >= min_arch<kernel>() && __CUDA_ARCH__ <= max_arch<kernel>()) {
// kernel()(args...);
// } else {
// trap_unsupported_arch();
// }
// #else
// ???
kernel()(args...);
#endif
// #endif
}
template<typename T>
......
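invoke_kernel now dispatches unconditionally because the __CUDA_ARCH__-based constexpr gating does not translate to HIP. If per-architecture gating is wanted again, one possible shape is sketched below; it is not part of this commit, and __gfx90a__/__gfx942__ are stand-ins for whatever offload targets the project actually builds for (hipcc defines one such macro per target during the device pass):

template <typename kernel, typename... T>
__global__ static void invoke_kernel(T... args) {
#if defined(__HIP_DEVICE_COMPILE__) && !(defined(__gfx90a__) || defined(__gfx942__))
  // Device pass for an unsupported target: report and stop instead of running.
  printf("invoke_kernel: unsupported gfx target\n");
  abort();
#else
  kernel()(args...);
#endif
}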
......@@ -122,8 +122,8 @@ public:
for (int mask = 2; mask > 0; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_GROUPS; i++) {
maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor(maxvalue[0][i], mask));
maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor(maxvalue[1][i], mask));
maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor(float(maxvalue[0][i]), mask));
maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor(float(maxvalue[1][i]), mask));
}
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical maxvalue now
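The float(...) casts added in this hunk work around the missing half-precision __shfl_xor overloads by shuffling the values as float. The same 4-lane butterfly max can also be written entirely in float and converted back once; a sketch, assuming half_t converts to and from float as it does elsewhere in this file (the helper name is made up):

template <typename half_t>
__device__ __forceinline__ half_t group4_max(half_t v) {
  float f = float(v);
  // Butterfly exchange: after masks 2 and 1, lanes 0-3, 4-7, ... agree on the max.
  for (int mask = 2; mask > 0; mask /= 2) {
    f = fmaxf(f, __shfl_xor(f, mask));
  }
  return half_t(f);
}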
......@@ -197,56 +197,57 @@ public:
int ida,
int idb) {
packed_f32psum_t out;
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13}, "
"{%14}, {%15, %16}, "
"{%17}, {%18, %19};"
: "=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
: "r"(act.x),
"r"(act.y),
"r"(act.z),
"r"(act.w),
"r"(wgt.x),
"r"(wgt.y),
"f"(psum.data[0]),
"f"(psum.data[1]),
"f"(psum.data[2]),
"f"(psum.data[3]),
"r"(amscale),
"n"(0),
"h"((short)ida),
"r"(wmscale),
"n"(0),
"h"((short)(idb * 2)));
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13}, "
"{%14}, {%15, %16}, "
"{%17}, {%18, %19};"
: "=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
: "r"(act.x),
"r"(act.y),
"r"(act.z),
"r"(act.w),
"r"(wgt.z),
"r"(wgt.w),
"f"(psum.data[4]),
"f"(psum.data[5]),
"f"(psum.data[6]),
"f"(psum.data[7]),
"r"(amscale),
"n"(0),
"h"((short)ida),
"r"(wmscale),
"n"(0),
"h"((short)(idb * 2 + 1)));
// asm volatile(
// "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
// "{%0, %1, %2, %3}, "
// "{%4, %5, %6, %7}, "
// "{%8, %9}, "
// "{%10, %11, %12, %13}, "
// "{%14}, {%15, %16}, "
// "{%17}, {%18, %19};"
// : "=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
// : "r"(act.x),
// "r"(act.y),
// "r"(act.z),
// "r"(act.w),
// "r"(wgt.x),
// "r"(wgt.y),
// "f"(psum.data[0]),
// "f"(psum.data[1]),
// "f"(psum.data[2]),
// "f"(psum.data[3]),
// "r"(amscale),
// "n"(0),
// "h"((short)ida),
// "r"(wmscale),
// "n"(0),
// "h"((short)(idb * 2)));
// asm volatile(
// "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
// "{%0, %1, %2, %3}, "
// "{%4, %5, %6, %7}, "
// "{%8, %9}, "
// "{%10, %11, %12, %13}, "
// "{%14}, {%15, %16}, "
// "{%17}, {%18, %19};"
// : "=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
// : "r"(act.x),
// "r"(act.y),
// "r"(act.z),
// "r"(act.w),
// "r"(wgt.z),
// "r"(wgt.w),
// "f"(psum.data[4]),
// "f"(psum.data[5]),
// "f"(psum.data[6]),
// "f"(psum.data[7]),
// "r"(amscale),
// "n"(0),
// "h"((short)ida),
// "r"(wmscale),
// "n"(0),
// "h"((short)(idb * 2 + 1)));
std::cout << __func__ << "mma_fp4 is not implemented for HIP yet[asm error!!!]" << std::endl;
return out;
}
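The replacement body reports the missing mma_fp4 path with std::cout. If this function is compiled for the device, iostreams are not available there; HIP device code does support printf and abort. A sketch only (the helper name is illustrative, and whether mma() is device-compiled is an assumption):

__device__ __forceinline__ void report_unimplemented(const char *what) {
  // Device-side printf goes to the host console; abort() traps the wavefront.
  printf("%s is not implemented for HIP yet\n", what);
  abort();
}

The body above would then call report_unimplemented(__func__) before returning out.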
......@@ -465,11 +466,11 @@ public:
}
#pragma unroll
for (int mask = 2; mask > 0; mask /= 2) {
maxvalue[0] = __hmax(maxvalue[0], __shfl_xor(maxvalue[0], mask));
maxvalue[1] = __hmax(maxvalue[1], __shfl_xor(maxvalue[1], mask));
maxvalue[0] = __hmax(maxvalue[0], __shfl_xor(float(maxvalue[0]), mask));
maxvalue[1] = __hmax(maxvalue[1], __shfl_xor(float(maxvalue[1]), mask));
}
maxvalue[0] = __shfl_sync(~0, maxvalue[0], laneId / 4 * 4);
maxvalue[1] = __shfl_sync(~0, maxvalue[1], laneId / 4 * 4);
maxvalue[0] = __shfl(float(maxvalue[0]), laneId / 4 * 4);
maxvalue[1] = __shfl(float(maxvalue[1]), laneId / 4 * 4);
float scale[2];
// scale[0] = float(maxvalue[0]) / QVALUE_MAX;
......@@ -577,14 +578,14 @@ public:
for (int mask = NUM_PACKS_PER_ROW / 2; mask > 0; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __hmax(maxvalue[i], __shfl_xor(maxvalue[i], mask));
maxvalue[i] = __hmax(maxvalue[i], __shfl_xor(float(maxvalue[i]), mask));
}
}
// broadcast (max)
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __shfl_sync(~0, maxvalue[i], laneId / NUM_PACKS_PER_ROW * NUM_PACKS_PER_ROW);
maxvalue[i] = __shfl(float(maxvalue[i]), laneId / NUM_PACKS_PER_ROW * NUM_PACKS_PER_ROW);
}
// quantize
......@@ -1150,7 +1151,7 @@ public:
fpsum_warp fpsum;
Base::template load_act_to_fpsum<fuse_glu>()(args.input + m_offset * args.actualN + n_offset,
typename Base::template load_act_to_fpsum<fuse_glu>()(args.input + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
......
......@@ -2,7 +2,7 @@
#include <cstdint>
#include "common.h"
#define __DTK_ARCH__ 1200
// only supports cuda 12.5+
namespace nunchaku::kernels {
......
......@@ -2,7 +2,7 @@
#include <cstdint>
#include "common.h"
#define __DTK_ARCH__ 1200
// cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now
......@@ -122,14 +122,14 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
// asm volatile(
// "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
// "{%0, %1, %2, %3},"
// "{%4, %5, %6, %7},"
// "{%8, %9},"
// "{%10, %11, %12, %13};\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
......@@ -176,14 +176,14 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
// asm volatile(
// "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
// "{%0, %1, %2, %3},"
// "{%4, %5, %6, %7},"
// "{%8, %9},"
// "{%10, %11, %12, %13};\n"
// : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
// : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
......