Commit 9316940c authored by fengzch's avatar fengzch
Browse files

fix: compile misc_kernels.cu complete

parent 038b8469
...@@ -225,6 +225,11 @@ __device__ inline T_OUT cuda_cast(T_IN val) { ...@@ -225,6 +225,11 @@ __device__ inline T_OUT cuda_cast(T_IN val) {
return val; return val;
} }
template<>
__device__ inline __hip_bfloat16 cuda_cast<__hip_bfloat16, long>(long val) {
return (long long)val;
}
template<> template<>
__device__ inline float2 cuda_cast<float2, int2>(int2 val) { __device__ inline float2 cuda_cast<float2, int2>(int2 val) {
return make_float2(val.x, val.y); return make_float2(val.x, val.y);
...@@ -268,7 +273,8 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val) { ...@@ -268,7 +273,8 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val) {
}; };
fp16 = val; fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); int16 = int16_in;
// asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0]; return int8[0];
} }
......
...@@ -101,7 +101,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4( ...@@ -101,7 +101,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
bool>; bool>;
if (shmem >= 24 * 1024) { if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); checkCUDA(cudaFuncSetAttribute(reinterpret_cast<const void*>(func), cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
} }
assert(alpha == 1.0f); assert(alpha == 1.0f);
......
...@@ -204,7 +204,14 @@ public: ...@@ -204,7 +204,14 @@ public:
#pragma unroll #pragma unroll
for (int mask = 32 / 2; mask > 0; mask /= 2) { for (int mask = 32 / 2; mask > 0; mask /= 2) {
maxvalue2 = __hmax2(maxvalue2, __shfl_xor(maxvalue2, mask)); __half2 m;
m.x = float(maxvalue2.x);
m.y = float(maxvalue2.y);
auto temp = __shfl_xor(m, mask);
__hip_bfloat162 n;
n.x = float(temp.x);
n.y = float(temp.y);
maxvalue2 = __hmax2(maxvalue2, n);
} }
return __hmax(maxvalue2.x, maxvalue2.y); return __hmax(maxvalue2.x, maxvalue2.y);
...@@ -243,9 +250,9 @@ public: ...@@ -243,9 +250,9 @@ public:
const int bm = blockIdx.x / (BLOCK_M / WARP_M); const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int gemmWarpId = blockIdx.x % (BLOCK_M / WARP_M); const int gemmWarpId = blockIdx.x % (BLOCK_M / WARP_M);
__shared__ alignas(128) half_t oscale_shmem[WARP_M]; __shared__ __attribute__((aligned(128))) half_t oscale_shmem[WARP_M];
// __shared__ alignas(128) half_t maxv_shmem[WARP_M]; // __shared__ alignas(128) half_t maxv_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[NUM_WARPS][512]; __shared__ __attribute__((aligned(128))) uint8_t tmp_shmem[NUM_WARPS][512];
const int K2 = fuse_glu ? K / 2 : K; const int K2 = fuse_glu ? K / 2 : K;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment