fengzch-das / nunchaku · Commits

Commit 1a8114bf
authored Nov 21, 2025 by fengzch-das

hipify code

parent c0177256
Pipeline #3049 canceled with stages
Changes 50 · Pipelines 2
Showing 10 changed files with 55 additions and 50 deletions (+55 −50)

src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.hip             +0  −0
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.hip            +0  −0
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.hip  +0  −0
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh                 +18 −17
src/kernels/zgemm/gemm_w4a4_test.hip                        +7  −6
src/kernels/zgemm/gemm_w8a8.cuh                             +2  −1
src/kernels/zgemm/gemm_w8a8.hip                             +13 −12
src/kernels/zgemm/mma.cuh                                   +3  −3
src/kernels/zgemm/mma_earlycuda.cuh                         +5  −5
src/pytorch_compat.h                                        +7  −6
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu → src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.hip
File moved

src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu → src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.hip
File moved

src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu → src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.hip
File moved
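The remaining files on this page keep their logic but have every CUDA runtime call and <<< >>> launch rewritten to the HIP equivalents. A minimal sketch of that mapping, using an illustrative kernel and arguments that are not from this repository:

// Minimal sketch of the CUDA -> HIP launch rewrite applied throughout this commit.
// Kernel, sizes, and argument names are illustrative, not the repository's.
#include <hip/hip_runtime.h>

__global__ void scale_kernel(float *data, float factor, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] *= factor;
}

void launch_scale(float *d_data, float factor, int n, hipStream_t stream) {
    dim3 grid((n + 255) / 256);
    dim3 block(256);
    // CUDA form: scale_kernel<<<grid, block, 0, stream>>>(d_data, factor, n);
    // HIP form emitted by the hipify tooling:
    hipLaunchKernelGGL(scale_kernel, grid, block, 0 /*dynamic shared mem*/, stream,
                       d_data, factor, n);
}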
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh

+#include "hip/hip_runtime.h"
 #include "gemm_w4a4_launch.cuh"

 namespace nunchaku::kernels {

@@ -84,7 +85,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
     //         (test_sizeof<decltype(args)>(), ...);
     //     }, args);
-    // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
+    // constexpr bool FP4_AVAILABLE = __DTK_ARCH__ >= 1200;
     if constexpr (!USE_FP4) {
         dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {

@@ -101,12 +102,12 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
                 bool>;
             if (shmem >= 24 * 1024) {
-                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+                checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, shmem));
             }
             assert(alpha == 1.0f);
-            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
+            hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), shmem, getCurrentHIPStreamMasqueradingAsCUDA(),
                 act.data_ptr<packed_act_t>(),
                 wgt.data_ptr<packed_wgt_t>(),
                 ascales.data_ptr<packed_ascale_t>(),

@@ -117,7 +118,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
                 args,
                 swapBlockMN,
                 false);
-            checkCUDA(cudaGetLastError());
+            checkCUDA(hipGetLastError());
         });
         return;
     }

@@ -140,13 +141,13 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
                 bool>;
             if (shmem >= 24 * 1024) {
-                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+                checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, shmem));
             }
             assert(ascales.dtype() == Tensor::FP8_E4M3);
             assert(wscales.dtype() == Tensor::FP8_E4M3);
-            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
+            hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), shmem, getCurrentHIPStreamMasqueradingAsCUDA(),
                 act.data_ptr<packed_act_t>(),
                 wgt.data_ptr<packed_wgt_t>(),
                 ascales.data_ptr<packed_amscale_t>(),

@@ -158,7 +159,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
                 args,
                 swapBlockMN,
                 false);
-            checkCUDA(cudaGetLastError());
+            checkCUDA(hipGetLastError());
         });
         return;

@@ -441,10 +442,10 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
         BLOCK_SIZE = 128;
     }
-    invoke_kernel<typename Epilogue::vk_mul_q_kernel>
-        <<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
+    hipLaunchKernelGGL((invoke_kernel<typename Epilogue::vk_mul_q_kernel>)
+        , dim3(dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size)), dim3(BLOCK_SIZE), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
         q.data_ptr<half_t>(), vk.data_ptr<float>(), 1e-6f, num_tokens);
-    checkCUDA(cudaGetLastError());
+    checkCUDA(hipGetLastError());
 }

 template<typename Config, bool USE_FP4>

@@ -495,12 +496,12 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
         auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+        checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
         // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N,
         //     input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
-        func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
+        hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), kernel::SHMEM_SIZE, getCurrentHIPStreamMasqueradingAsCUDA(),
             typename kernel::Arguments{
                 .input = input.data_ptr<half_t>(),
                 .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,

@@ -515,7 +516,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
                 .actualN = actualN,
                 .alwaysfalse = false,
             });
-        checkCUDA(cudaGetLastError());
+        checkCUDA(hipGetLastError());
     });
     // });
 }

@@ -539,9 +540,9 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o
     assert(oscales.numel() == M * K / GEMM::WARP_K);
     dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
-    invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
+    hipLaunchKernelGGL((invoke_kernel<typename GEMM::quantize_w4a4_act_kernel>), dim3(grid), dim3(GEMM::WARP_SIZE), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
         input.data_ptr<half_t>(), output.data_ptr<packed_act_t>(), oscales.data_ptr<packed_ascale_t>(), K);
-    checkCUDA(cudaGetLastError());
+    checkCUDA(hipGetLastError());
 }

 template<typename Config, bool USE_FP4>

@@ -564,9 +565,9 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o
     assert(oscales.numel() == N * K / GEMM::WARP_K);
     dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
-    invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
+    hipLaunchKernelGGL((invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel>), dim3(grid), dim3(GEMM::WARP_SIZE), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
         input.data_ptr<half_t>(), output.data_ptr<packed_wgt_t>(), oscales.data_ptr<packed_wscale_t>(), K);
-    checkCUDA(cudaGetLastError());
+    checkCUDA(hipGetLastError());
 }

 }; // namespace nunchaku::kernels
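The shmem >= 24 * 1024 branches above exist because kernels that request more dynamic shared memory than the default per-launch limit must opt in explicitly, which the hipified code does through hipFuncSetAttribute. A minimal standalone sketch of that opt-in, with an illustrative kernel that is not part of this repository:

// Minimal sketch of the dynamic shared-memory opt-in mirrored by the shmem >= 24*1024 branches.
// Kernel and sizes are illustrative; the exact default limit depends on the device.
#include <hip/hip_runtime.h>

__global__ void shmem_kernel(float *out, int n) {
    extern __shared__ float buf[];   // dynamic shared memory, sized at launch time
    int i = threadIdx.x;
    if (i < n)
        buf[i] = float(i);
    __syncthreads();
    if (i < n)
        out[i] = buf[i];
}

void launch_with_big_shmem(float *d_out, int n, hipStream_t stream) {
    size_t shmem = 64 * 1024;        // may exceed the device's default limit, so opt in first
    hipFuncSetAttribute(reinterpret_cast<const void *>(&shmem_kernel),
                        hipFuncAttributeMaxDynamicSharedMemorySize,
                        static_cast<int>(shmem));
    // n is assumed to fit in one block here
    hipLaunchKernelGGL(shmem_kernel, dim3(1), dim3(n), shmem, stream, d_out, n);
}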
src/kernels/zgemm/gemm_w4a4_test.cu → src/kernels/zgemm/gemm_w4a4_test.hip

+#include "hip/hip_runtime.h"
 #include "zgemm.h"
 #include "gemm_w4a4.cuh"
 #include "epilogues.cuh"

@@ -21,11 +22,11 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
     auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+    checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
     dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
-    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
+    hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), kernel::SHMEM_SIZE, getCurrentHIPStreamMasqueradingAsCUDA(),
         typename kernel::Arguments{.input = input.data_ptr<GEMM::half_t>(),
                                    .output = output.data_ptr<GEMM::half_t>(),
                                    .M = M,

@@ -38,7 +39,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
                                    .rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
                                    .epsilon = 1e-6,
                                    }});
-    checkCUDA(cudaGetLastError());
+    checkCUDA(hipGetLastError());
 }

 void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {

@@ -59,11 +60,11 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
     auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+    checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
     dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
-    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
+    hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), kernel::SHMEM_SIZE, getCurrentHIPStreamMasqueradingAsCUDA(),
         typename kernel::Arguments{
             .input = input.data_ptr<GEMM::half_t>(),
             .output = output.data_ptr<GEMM::half_t>(),

@@ -83,7 +84,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
             .strideHead_v =
                 int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
         }});
-    checkCUDA(cudaGetLastError());
+    checkCUDA(hipGetLastError());
 }

 }; // namespace nunchaku::kernels
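The test launches above all go through the repository's invoke_kernel<kernel, typename kernel::Arguments> trampoline and pass a designated-initializer Arguments struct by value. A minimal sketch of that pattern under HIP, with illustrative names standing in for the repository's helpers:

// Minimal sketch of the invoke_kernel + Arguments pattern used by these launches.
// AddOne and its members are illustrative stand-ins, not the repository's types.
#include <hip/hip_runtime.h>

template <typename Kernel, typename... Args>
__global__ void invoke_kernel(Args... args) {
    Kernel()(args...);   // construct the functor on device and forward the arguments
}

struct AddOne {
    struct Arguments {
        float *data;
        int n;
    };
    __device__ void operator()(Arguments args) const {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < args.n)
            args.data[i] += 1.0f;
    }
};

void launch_add_one(float *d_data, int n, hipStream_t stream) {
    dim3 grid((n + 255) / 256), block(256);
    // Extra parentheses keep the comma inside the template argument list out of the macro.
    hipLaunchKernelGGL((invoke_kernel<AddOne, AddOne::Arguments>), grid, block, 0, stream,
                       AddOne::Arguments{.data = d_data, .n = n});
}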
src/kernels/zgemm/gemm_w8a8.cuh

+#include "hip/hip_runtime.h"
 #pragma once

 #include "gemm_base.cuh"

@@ -438,7 +439,7 @@ public:
     // out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
     template<typename Epilogue>
     struct gemm_w8a8_kernel {
-        static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
+        static constexpr int MIN_ARCH = std::is_same_v<half_t, __hip_bfloat16> ? 800 : 750;
         __device__ void operator()(const packed_act_t *act,
                                    const packed_wgt_t *wgt,
                                    const packed_ascale_t *ascales,
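The only source change in this header swaps the bf16 type used in the MIN_ARCH check: __nv_bfloat16 comes from the CUDA headers, __hip_bfloat16 from the HIP/DTK ones. A minimal sketch of selecting the type per platform (the platform macro and header names are assumptions and can vary by toolkit version):

// Minimal sketch of picking the 16-bit bfloat type per platform.
// The platform macro and header choice are assumptions; adjust to the toolkit in use.
#include <type_traits>
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_bf16.h>
using bf16_t = __hip_bfloat16;
#else
#include <cuda_bf16.h>
using bf16_t = __nv_bfloat16;
#endif

template <typename half_t>
constexpr int min_arch_for() {
    // bf16 MMA paths require a newer architecture than the fp16 paths.
    return std::is_same_v<half_t, bf16_t> ? 800 : 750;
}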
src/kernels/zgemm/gemm_w8a8.cu → src/kernels/zgemm/gemm_w8a8.hip

+#include "hip/hip_runtime.h"
 #include "zgemm.h"
 #include "gemm_w8a8.cuh"

@@ -26,14 +27,14 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
         auto func =
             invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;
-        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
+        checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, 92160));
-        func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(),
+        hipLaunchKernelGGL((func), dim3(grid), dim3(block), kernel::smemSize(M, K), 0, input.data_ptr<GEMM::half_t>(),
             output.data_ptr<GEMM::packed_act_t>(),
             oscales.data_ptr<GEMM::packed_ascale_t>(),
             K,
             false);
-        checkCUDA(cudaGetLastError());
+        checkCUDA(hipGetLastError());
     };

     if (fuse_glu) {

@@ -74,8 +75,8 @@ void gemm_w8a8(Tensor act, // [M, K]
             std::swap(grid.x, grid.y);
         }
-        invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>
-            <<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(act.data_ptr<GEMM::packed_act_t>(),
+        hipLaunchKernelGGL((invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>)
+            , dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), 0, 0, act.data_ptr<GEMM::packed_act_t>(),
             wgt.data_ptr<GEMM::packed_wgt_t>(),
             ascales.data_ptr<GEMM::packed_ascale_t>(),
             wscales.data_ptr<GEMM::packed_wscale_t>(),

@@ -86,7 +87,7 @@ void gemm_w8a8(Tensor act, // [M, K]
             args,
             swapBlockMN,
             false);
-        checkCUDA(cudaGetLastError());
+        checkCUDA(hipGetLastError());
     };

     auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {

@@ -147,7 +148,7 @@ void gemm_w8a8_fuse_litela(
     epilogueArgs.out_q  = out_q.data_ptr<GEMM::half_t>();
     epilogueArgs.out_vk = out_vk.data_ptr<float>();
-    checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
+    checkCUDA(hipMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));

     auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
                               const GEMM::packed_act_t *,

@@ -160,7 +161,7 @@ void gemm_w8a8_fuse_litela(
                               bool,
                               bool>;
-    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));
+    checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));

     dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

@@ -169,7 +170,7 @@ void gemm_w8a8_fuse_litela(
         std::swap(grid.x, grid.y);
     }
-    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
+    hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), Epilogue::SHMEM_SIZE, 0,
         act.data_ptr<GEMM::packed_act_t>(),
         wgt.data_ptr<GEMM::packed_wgt_t>(),
         ascales.data_ptr<GEMM::packed_ascale_t>(),

@@ -179,14 +180,14 @@ void gemm_w8a8_fuse_litela(
         swapBlockMN,
         false
     );
-    checkCUDA(cudaGetLastError());
+    checkCUDA(hipGetLastError());

-    invoke_kernel<Epilogue::vk_mul_q_kernel>
-        <<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
+    hipLaunchKernelGGL((invoke_kernel<Epilogue::vk_mul_q_kernel>)
+        , dim3(dim3(batch_m / 128, num_heads, batch_size)), dim3(128), 0, 0,
         out_q.data_ptr<GEMM::half_t>(),
         out_vk.data_ptr<float>(),
         1e-6f
     );
-    checkCUDA(cudaGetLastError());
+    checkCUDA(hipGetLastError());
 }
 #endif
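Every HIP runtime call in these hunks is still routed through the project's checkCUDA helper. A minimal sketch of such an error-check wrapper around the calls seen above (the macro body is an assumption for illustration; the repository defines its own helper):

// Minimal sketch of a HIP error-check wrapper in the spirit of checkCUDA.
#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

#define CHECK_HIP(expr)                                                     \
    do {                                                                    \
        hipError_t err_ = (expr);                                           \
        if (err_ != hipSuccess) {                                           \
            std::fprintf(stderr, "HIP error %s at %s:%d\n",                 \
                         hipGetErrorString(err_), __FILE__, __LINE__);      \
            std::abort();                                                   \
        }                                                                   \
    } while (0)

void zero_async(void *ptr, size_t bytes, hipStream_t stream) {
    CHECK_HIP(hipMemsetAsync(ptr, 0, bytes, stream));   // as in the out_vk reset above
    CHECK_HIP(hipGetLastError());                       // as after each kernel launch
}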
src/kernels/zgemm/mma.cuh

@@ -35,7 +35,7 @@ using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
 __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
     uint2 d;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                  "{%0, %1},"
                  "{%2, %3, %4, %5},"

@@ -66,7 +66,7 @@ __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2
 template<bool is_bf16>
 __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
     uint4 d;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
                  "{%0, %1, %2, %3},"
                  "{%4, %5, %6, %7},"

@@ -110,7 +110,7 @@ __device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b,
     uint4 d;
     static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
                  "{%0, %1, %2, %3},"
                  "{%4, %5, %6, %7},"
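These hunks only retarget the preprocessor guard around the inline mma.sync assembly, replacing __CUDA_ARCH__ with __DTK_ARCH__, the device-side architecture macro of the target toolkit. A minimal sketch of the guard pattern itself (the fallback body is illustrative, and whether __DTK_ARCH__ is predefined is an assumption about the DTK compiler):

// Minimal sketch of the device-arch guard pattern these hunks retarget.
// __DTK_ARCH__ stands in for __CUDA_ARCH__ here; the fallback path is illustrative.
__device__ __forceinline__ static float fma_guarded(float a, float b, float c) {
#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
    return __fmaf_rn(a, b, c);   // fast path, compiled only for new enough devices
#else
    return a * b + c;            // host pass / older devices fall back to plain arithmetic
#endif
}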
src/kernels/zgemm/mma_earlycuda.cuh

@@ -36,7 +36,7 @@ using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
 __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
     uint2 d;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                  "{%0, %1},"
                  "{%2, %3, %4, %5},"

@@ -67,7 +67,7 @@ __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2
 template<bool is_bf16>
 __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
 template<>
 __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
     uint4 d;

@@ -85,7 +85,7 @@ __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2
 template<>
 __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
     uint4 d;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                  "{%0, %1, %2, %3},"
                  "{%4, %5, %6, %7},"

@@ -121,7 +121,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
     uint4 d;
     static constexpr int K = 64;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
     asm volatile(
         "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
         "{%0, %1, %2, %3},"

@@ -175,7 +175,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
     uint4 d;
     static constexpr int K = 64;
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
     asm volatile(
         "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
         "{%0, %1, %2, %3},"
src/pytorch_compat.h

+#include "hip/hip_runtime.h"
 #pragma once

 #include "common.h"

@@ -9,7 +10,7 @@ inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
 }

 template<typename T>
-inline void C10_CUDA_CHECK(T ret) {
+inline void C10_HIP_CHECK(T ret) {
     return checkCUDA(ret);
 }

@@ -34,16 +35,16 @@ namespace cuda {
 using ::getCurrentDeviceProperties;

 struct StreamWrapper {
-    cudaStream_t st;
-    cudaStream_t stream() const {
+    hipStream_t st;
+    hipStream_t stream() const {
         return st;
     }
 };

-inline StreamWrapper getCurrentCUDAStream() {
-    return StreamWrapper(::getCurrentCUDAStream());
+inline StreamWrapper getCurrentHIPStreamMasqueradingAsCUDA() {
+    return StreamWrapper(::getCurrentHIPStreamMasqueradingAsCUDA());
 }

-struct CUDAGuard {
+struct HIPGuardMasqueradingAsCUDA {
     int dev;
 };
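After this change the compat header hands out hipStream_t through the renamed getCurrentHIPStreamMasqueradingAsCUDA wrapper, so call sites keep a CUDA-shaped API while running on HIP. A minimal sketch of how a caller consumes it (the getter body is a stand-in; the real header forwards to the project's global stream getter):

// Minimal sketch of consuming the StreamWrapper shim defined above.
#include <hip/hip_runtime.h>
#include <cstddef>

struct StreamWrapper {
    hipStream_t st;
    hipStream_t stream() const { return st; }
};

inline StreamWrapper getCurrentHIPStreamMasqueradingAsCUDA() {
    return StreamWrapper{nullptr};   // stand-in: the default (null) stream
}

void zero_on_current_stream(void *ptr, size_t bytes) {
    hipStream_t s = getCurrentHIPStreamMasqueradingAsCUDA().stream();
    hipMemsetAsync(ptr, 0, bytes, s);
}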
fengzch-das @Fzc7075 mentioned in commit 0a7c8614 · Nov 21, 2025
fengzch-das @Fzc7075 mentioned in merge request !1 (merged) · Nov 21, 2025
fengzch-das @Fzc7075 mentioned in commit 7c282e2e · Nov 21, 2025