"profiler/vscode:/vscode.git/clone" did not exist on "d8f1458f448f4509d950ef04adc15eda622a9c5d"
Commit 0a7c8614 authored by fengzch-das

Revert "hipify code"

This reverts commit 1a8114bf
parent 1a8114bf
Pipeline #3050 failed with stages in 0 seconds
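For context on what is being reverted: hipify rewrites CUDA runtime calls and triple-chevron kernel launches into their HIP equivalents, and this commit restores the CUDA forms (the '+' lines in the hunks below). A minimal sketch of the launch-syntax mapping, using a hypothetical kernel named scale that is not part of the repository:

#include <cuda_runtime.h>
#include <cassert>

// Hypothetical kernel used only to illustrate the syntax; not repository code.
__global__ void scale(float *x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= 2.0f;
}

void launch_scale(float *x, int n, cudaStream_t stream) {
    dim3 block(256);
    dim3 grid((n + block.x - 1) / block.x);
    // CUDA form restored by this revert:
    scale<<<grid, block, 0, stream>>>(x, n);
    // hipify would have rewritten the line above to roughly:
    //   hipLaunchKernelGGL(scale, grid, block, 0, stream, x, n);
    assert(cudaGetLastError() == cudaSuccess); // stands in for the repo's checkCUDA()
}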
-#include "hip/hip_runtime.h"
 #include "gemm_w4a4_launch.cuh"
 namespace nunchaku::kernels {
@@ -85,7 +84,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
 // (test_sizeof<decltype(args)>(), ...);
 // }, args);
-// constexpr bool FP4_AVAILABLE = __DTK_ARCH__ >= 1200;
+// constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
 if constexpr (!USE_FP4) {
 dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
@@ -102,12 +101,12 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
 bool>;
 if (shmem >= 24 * 1024) {
-checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, shmem));
+checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
 }
 assert(alpha == 1.0f);
-hipLaunchKernelGGL(( func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), shmem, getCurrentHIPStreamMasqueradingAsCUDA(),
+func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
 act.data_ptr<packed_act_t>(),
 wgt.data_ptr<packed_wgt_t>(),
 ascales.data_ptr<packed_ascale_t>(),
@@ -118,7 +117,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
 args,
 swapBlockMN,
 false);
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 });
 return;
 }
@@ -141,13 +140,13 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
 bool>;
 if (shmem >= 24 * 1024) {
-checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, shmem));
+checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
 }
 assert(ascales.dtype() == Tensor::FP8_E4M3);
 assert(wscales.dtype() == Tensor::FP8_E4M3);
-hipLaunchKernelGGL(( func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), shmem, getCurrentHIPStreamMasqueradingAsCUDA(),
+func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
 act.data_ptr<packed_act_t>(),
 wgt.data_ptr<packed_wgt_t>(),
 ascales.data_ptr<packed_amscale_t>(),
@@ -159,7 +158,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
 args,
 swapBlockMN,
 false);
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 });
 return;
@@ -442,10 +441,10 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
 BLOCK_SIZE = 128;
 }
-hipLaunchKernelGGL(( invoke_kernel<typename Epilogue::vk_mul_q_kernel>)
-, dim3(dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size)), dim3(BLOCK_SIZE), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
+invoke_kernel<typename Epilogue::vk_mul_q_kernel>
+<<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
 q.data_ptr<half_t>(), vk.data_ptr<float>(), 1e-6f, num_tokens);
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 }
 template<typename Config, bool USE_FP4>
@@ -496,12 +495,12 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
 auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
 // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N,
 // input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
-hipLaunchKernelGGL(( func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), kernel::SHMEM_SIZE, getCurrentHIPStreamMasqueradingAsCUDA(),
+func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
 typename kernel::Arguments{
 .input = input.data_ptr<half_t>(),
 .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
@@ -516,7 +515,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
 .actualN = actualN,
 .alwaysfalse = false,
 });
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 });
 // });
 }
@@ -540,9 +539,9 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o
 assert(oscales.numel() == M * K / GEMM::WARP_K);
 dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
-hipLaunchKernelGGL(( invoke_kernel<typename GEMM::quantize_w4a4_act_kernel>), dim3(grid), dim3(GEMM::WARP_SIZE), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
+invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
 input.data_ptr<half_t>(), output.data_ptr<packed_act_t>(), oscales.data_ptr<packed_ascale_t>(), K);
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 }
 template<typename Config, bool USE_FP4>
@@ -565,9 +564,9 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o
 assert(oscales.numel() == N * K / GEMM::WARP_K);
 dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
-hipLaunchKernelGGL(( invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel>), dim3(grid), dim3(GEMM::WARP_SIZE), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
+invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
 input.data_ptr<half_t>(), output.data_ptr<packed_wgt_t>(), oscales.data_ptr<packed_wscale_t>(), K);
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 }
 }; // namespace nunchaku::kernels
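The launches restored above raise the dynamic shared-memory cap with cudaFuncSetAttribute whenever shmem is at least 24 KB, then launch with a runtime-sized allocation. A standalone sketch of that pattern (the kernel, sizes, and main() harness are illustrative, not repository code):

#include <cuda_runtime.h>
#include <cstdio>

// Illustrative kernel that touches a dynamically sized shared-memory buffer.
__global__ void uses_dynamic_smem(float *out) {
    extern __shared__ float buf[];          // sized by the third <<<>>> launch parameter
    buf[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = buf[blockDim.x - 1 - threadIdx.x];
}

int main() {
    const int shmem = 64 * 1024;            // above the 48 KB default; needs compute capability 7.0+
    // Same call the diff restores: raise the per-kernel dynamic shared-memory cap before launching.
    if (cudaFuncSetAttribute(uses_dynamic_smem, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem) != cudaSuccess) {
        printf("this GPU does not allow %d bytes of dynamic shared memory\n", shmem);
        return 1;
    }
    float *out = nullptr;
    cudaMalloc(&out, 256 * sizeof(float));
    uses_dynamic_smem<<<1, 256, shmem>>>(out);
    printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(out);
    return 0;
}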
-#include "hip/hip_runtime.h"
 #include "zgemm.h"
 #include "gemm_w4a4.cuh"
 #include "epilogues.cuh"
@@ -22,11 +21,11 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
 auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
 dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
-hipLaunchKernelGGL(( func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), kernel::SHMEM_SIZE, getCurrentHIPStreamMasqueradingAsCUDA(),
+func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
 typename kernel::Arguments{.input = input.data_ptr<GEMM::half_t>(),
 .output = output.data_ptr<GEMM::half_t>(),
 .M = M,
@@ -39,7 +38,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
 .rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
 .epsilon = 1e-6,
 }});
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 }
 void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
@@ -60,11 +59,11 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
 auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
 dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
-hipLaunchKernelGGL(( func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), kernel::SHMEM_SIZE, getCurrentHIPStreamMasqueradingAsCUDA(),
+func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
 typename kernel::Arguments{
 .input = input.data_ptr<GEMM::half_t>(),
 .output = output.data_ptr<GEMM::half_t>(),
@@ -84,7 +83,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
 .strideHead_v =
 int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
 }});
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 }
 }; // namespace nunchaku::kernels
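Both restored launches in this file pass a single kernel::Arguments aggregate, built with designated initializers, straight into the triple-chevron call. A simplified sketch of that calling convention (the struct and kernel here are illustrative and far smaller than the real ones; designated initializers require C++20, which the restored code already relies on):

#include <cuda_runtime.h>

// Illustrative argument bundle; the real kernel::Arguments structs carry many more fields.
struct Arguments {
    const float *input;
    float *output;
    int M;
    int N;
};

// Stand-in for invoke_kernel<kernel, Arguments>: the bundle is passed by value to the kernel.
__global__ void copy_kernel(Arguments args) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < args.M * args.N)
        args.output[i] = args.input[i];
}

void launch_copy(const float *in, float *out, int M, int N, cudaStream_t stream) {
    int total = M * N;
    dim3 grid((total + 255) / 256);
    // Same style as the restored code: build the struct at the call site with designated initializers.
    copy_kernel<<<grid, 256, 0, stream>>>(Arguments{.input = in, .output = out, .M = M, .N = N});
}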
-#include "hip/hip_runtime.h"
 #include "zgemm.h"
 #include "gemm_w8a8.cuh"
@@ -27,14 +26,14 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
 auto func =
 invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;
-checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, 92160));
-hipLaunchKernelGGL(( func), dim3(grid), dim3(block), kernel::smemSize(M, K), 0, input.data_ptr<GEMM::half_t>(),
+checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
+func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(),
 output.data_ptr<GEMM::packed_act_t>(),
 oscales.data_ptr<GEMM::packed_ascale_t>(),
 K,
 false);
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 };
 if (fuse_glu) {
@@ -75,8 +74,8 @@ void gemm_w8a8(Tensor act, // [M, K]
 std::swap(grid.x, grid.y);
 }
-hipLaunchKernelGGL(( invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>)
-, dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), 0, 0, act.data_ptr<GEMM::packed_act_t>(),
+invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>
+<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(act.data_ptr<GEMM::packed_act_t>(),
 wgt.data_ptr<GEMM::packed_wgt_t>(),
 ascales.data_ptr<GEMM::packed_ascale_t>(),
 wscales.data_ptr<GEMM::packed_wscale_t>(),
@@ -87,7 +86,7 @@ void gemm_w8a8(Tensor act, // [M, K]
 args,
 swapBlockMN,
 false);
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 };
 auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
@@ -148,7 +147,7 @@ void gemm_w8a8_fuse_litela(
 epilogueArgs.out_q = out_q.data_ptr<GEMM::half_t>();
 epilogueArgs.out_vk = out_vk.data_ptr<float>();
-checkCUDA(hipMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
+checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
 auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
 const GEMM::packed_act_t *,
@@ -161,7 +160,7 @@ void gemm_w8a8_fuse_litela(
 bool,
 bool>;
-checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));
+checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));
 dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
@@ -170,7 +169,7 @@ void gemm_w8a8_fuse_litela(
 std::swap(grid.x, grid.y);
 }
-hipLaunchKernelGGL(( func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), Epilogue::SHMEM_SIZE, 0,
+func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
 act.data_ptr<GEMM::packed_act_t>(),
 wgt.data_ptr<GEMM::packed_wgt_t>(),
 ascales.data_ptr<GEMM::packed_ascale_t>(),
@@ -180,14 +179,14 @@ void gemm_w8a8_fuse_litela(
 swapBlockMN,
 false
 );
-checkCUDA(hipGetLastError());
-hipLaunchKernelGGL(( invoke_kernel<Epilogue::vk_mul_q_kernel>), dim3(dim3(batch_m / 128, num_heads, batch_size)), dim3(128), 0, 0,
+checkCUDA(cudaGetLastError());
+invoke_kernel<Epilogue::vk_mul_q_kernel><<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
 out_q.data_ptr<GEMM::half_t>(),
 out_vk.data_ptr<float>(),
 1e-6f
 );
-checkCUDA(hipGetLastError());
+checkCUDA(cudaGetLastError());
 }
 #endif
...
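The restored gemm_w8a8_fuse_litela above first zeroes out_vk with cudaMemsetAsync, then runs the fused GEMM that accumulates into it, and finally launches vk_mul_q. A reduced sketch of the clear-then-accumulate step (the kernel and buffer sizes are illustrative; like the restored code, it relies on the default stream to order the memset before the kernel):

#include <cuda_runtime.h>

// Illustrative accumulator kernel: many inputs are summed into a small output buffer.
__global__ void accumulate(const float *in, float *acc, int n, int acc_len) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        atomicAdd(&acc[i % acc_len], in[i]);
}

void run(const float *in, float *acc, int n, int acc_len) {
    // Same pattern the diff restores: clear the accumulator asynchronously on the default
    // stream so it is ordered before the kernel that accumulates into it.
    cudaMemsetAsync(acc, 0, acc_len * sizeof(float));
    accumulate<<<(n + 255) / 256, 256>>>(in, acc, n, acc_len);
    cudaGetLastError(); // stands in for the repo's checkCUDA()
}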
-#include "hip/hip_runtime.h"
 #pragma once
 #include "gemm_base.cuh"
@@ -439,7 +438,7 @@ public:
 // out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
 template<typename Epilogue>
 struct gemm_w8a8_kernel {
-static constexpr int MIN_ARCH = std::is_same_v<half_t, __hip_bfloat16> ? 800 : 750;
+static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
 __device__ void operator()(const packed_act_t *act,
 const packed_wgt_t *wgt,
 const packed_ascale_t *ascales,
...
@@ -35,7 +35,7 @@ using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
 __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
 uint2 d;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
 "{%0, %1},"
 "{%2, %3, %4, %5},"
@@ -66,7 +66,7 @@ __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2
 template<bool is_bf16>
 __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
 uint4 d;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
 "{%0, %1, %2, %3},"
 "{%4, %5, %6, %7},"
@@ -110,7 +110,7 @@ __device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b,
 uint4 d;
 static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
 "{%0, %1, %2, %3},"
 "{%4, %5, %6, %7},"
...
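The mma helpers in these headers wrap PTX mma.sync instructions in '#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800' so the inline assembly is only emitted for architectures that support it. A minimal sketch of that guard pattern using a deliberately simple PTX instruction (the real helpers guard tensor-core mma instructions and have no fallback):

#include <cuda_runtime.h>

// Minimal illustration of the arch guard used by the mma helpers: the inline PTX is only
// compiled when the device pass targets compute capability 8.0 or newer.
__device__ __forceinline__ unsigned int add_u32(unsigned int a, unsigned int b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    unsigned int d;
    asm volatile("add.u32 %0, %1, %2;" : "=r"(d) : "r"(a), "r"(b));
    return d;
#else
    // Fallback for older architectures and the host pass (illustrative only).
    return a + b;
#endif
}

__global__ void demo(unsigned int *out) {
    out[threadIdx.x] = add_u32(threadIdx.x, 1u);
}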
@@ -36,7 +36,7 @@ using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
 __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
 uint2 d;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
 "{%0, %1},"
 "{%2, %3, %4, %5},"
@@ -67,7 +67,7 @@ __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2
 template<bool is_bf16>
 __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 template<>
 __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
 uint4 d;
@@ -85,7 +85,7 @@ __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2
 template<>
 __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
 uint4 d;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
 "{%0, %1, %2, %3},"
 "{%4, %5, %6, %7},"
@@ -121,7 +121,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
 uint4 d;
 static constexpr int K = 64;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 asm volatile(
 "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
 "{%0, %1, %2, %3},"
@@ -175,7 +175,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
 uint4 d;
 static constexpr int K = 64;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 asm volatile(
 "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
 "{%0, %1, %2, %3},"
...
-#include "hip/hip_runtime.h"
 #pragma once
 #include "common.h"
@@ -10,7 +9,7 @@ inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
 }
 template<typename T>
-inline void C10_HIP_CHECK(T ret) {
+inline void C10_CUDA_CHECK(T ret) {
 return checkCUDA(ret);
 }
@@ -35,16 +34,16 @@ namespace cuda {
 using ::getCurrentDeviceProperties;
 struct StreamWrapper {
-hipStream_t st;
-hipStream_t stream() const {
+cudaStream_t st;
+cudaStream_t stream() const {
 return st;
 }
 };
-inline StreamWrapper getCurrentHIPStreamMasqueradingAsCUDA() {
-return StreamWrapper(::getCurrentHIPStreamMasqueradingAsCUDA());
+inline StreamWrapper getCurrentCUDAStream() {
+return StreamWrapper(::getCurrentCUDAStream());
 }
-struct HIPGuardMasqueradingAsCUDA {
+struct CUDAGuard {
 int dev;
 };
...
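The last file restores a small Torch-free compatibility shim: getCurrentCUDAStream() returns a StreamWrapper whose .stream() accessor mirrors the interface used at the launch sites above. A usage sketch, under the assumption that the wrapper (or the Torch stream object it imitates) converts to cudaStream_t at the launch site:

#include <cuda_runtime.h>

// Sketch of the shim restored above. The conversion operator is an assumption added here so
// the wrapper can be passed directly to <<<...>>>; only st and stream() appear in the hunk.
struct StreamWrapper {
    cudaStream_t st;
    cudaStream_t stream() const { return st; }
    operator cudaStream_t() const { return st; }
};

inline StreamWrapper getCurrentCUDAStream() {
    return StreamWrapper{nullptr}; // the default stream stands in for the real per-thread query
}

__global__ void noop() {}

int main() {
    // Mirrors the call sites in the .cu files above: the current stream is queried per launch.
    noop<<<1, 1, 0, getCurrentCUDAStream()>>>();
    return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}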