fengzch-das / nunchaku

Commit 0a7c8614, authored Nov 21, 2025 by fengzch-das

Revert "hipify code"

This reverts commit 1a8114bf.

Parent: 1a8114bf
Pipeline #3050 failed with stages in 0 seconds
Showing 10 changed files with 50 additions and 55 deletions.
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu             +0  -0
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu            +0  -0
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu  +0  -0
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh                +17 -18
src/kernels/zgemm/gemm_w4a4_test.cu                        +6  -7
src/kernels/zgemm/gemm_w8a8.cu                             +12 -13
src/kernels/zgemm/gemm_w8a8.cuh                            +1  -2
src/kernels/zgemm/mma.cuh                                  +3  -3
src/kernels/zgemm/mma_earlycuda.cuh                        +5  -5
src/pytorch_compat.h                                       +6  -7
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.hip → src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu   (file moved)
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.hip → src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu   (file moved)
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.hip → src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu   (file moved)
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh

-#include "hip/hip_runtime.h"
 #include "gemm_w4a4_launch.cuh"

 namespace nunchaku::kernels {
...
@@ -85,7 +84,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
     //     (test_sizeof<decltype(args)>(), ...);
     // }, args);
-    // constexpr bool FP4_AVAILABLE = __DTK_ARCH__ >= 1200;
+    // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
     if constexpr (!USE_FP4) {
         dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
...
@@ -102,12 +101,12 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
                                   bool>;
     if (shmem >= 24 * 1024) {
-        checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
     }
     assert(alpha == 1.0f);
-    hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), shmem, getCurrentHIPStreamMasqueradingAsCUDA(),
+    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
         act.data_ptr<packed_act_t>(),
         wgt.data_ptr<packed_wgt_t>(),
         ascales.data_ptr<packed_ascale_t>(),
...
@@ -118,7 +117,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
         args,
         swapBlockMN,
         false);
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
     });
     return;
 }
...
@@ -141,13 +140,13 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
                                   bool>;
     if (shmem >= 24 * 1024) {
-        checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
     }
     assert(ascales.dtype() == Tensor::FP8_E4M3);
     assert(wscales.dtype() == Tensor::FP8_E4M3);
-    hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), shmem, getCurrentHIPStreamMasqueradingAsCUDA(),
+    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
         act.data_ptr<packed_act_t>(),
         wgt.data_ptr<packed_wgt_t>(),
         ascales.data_ptr<packed_amscale_t>(),
...
@@ -159,7 +158,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
         args,
         swapBlockMN,
         false);
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
     });
     return;
...
@@ -442,10 +441,10 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
         BLOCK_SIZE = 128;
     }
-    hipLaunchKernelGGL((invoke_kernel<typename Epilogue::vk_mul_q_kernel>), dim3(dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size)), dim3(BLOCK_SIZE), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
+    invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
         q.data_ptr<half_t>(),
         vk.data_ptr<float>(),
         1e-6f,
         num_tokens);
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
 }

 template<typename Config, bool USE_FP4>
...
@@ -496,12 +495,12 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
     auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-    checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
     // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N,
     //     input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
-    hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), kernel::SHMEM_SIZE, getCurrentHIPStreamMasqueradingAsCUDA(),
+    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
         typename kernel::Arguments{
             .input = input.data_ptr<half_t>(),
             .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
...
@@ -516,7 +515,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
             .actualN = actualN,
             .alwaysfalse = false,
         });
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
     });
     // });
 }
...
@@ -540,9 +539,9 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o
     assert(oscales.numel() == M * K / GEMM::WARP_K);
     dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
-    hipLaunchKernelGGL((invoke_kernel<typename GEMM::quantize_w4a4_act_kernel>), dim3(grid), dim3(GEMM::WARP_SIZE), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
+    invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
         input.data_ptr<half_t>(),
         output.data_ptr<packed_act_t>(),
         oscales.data_ptr<packed_ascale_t>(),
         K);
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
 }

 template<typename Config, bool USE_FP4>
...
@@ -565,9 +564,9 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o
     assert(oscales.numel() == N * K / GEMM::WARP_K);
     dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
-    hipLaunchKernelGGL((invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel>), dim3(grid), dim3(GEMM::WARP_SIZE), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
+    invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
         input.data_ptr<half_t>(),
         output.data_ptr<packed_wgt_t>(),
         oscales.data_ptr<packed_wscale_t>(),
         K);
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
 }
 }; // namespace nunchaku::kernels
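The launch-site changes above follow the usual HIP-to-CUDA mapping: hipLaunchKernelGGL(kernel, gridDim, blockDim, dynamicShmem, stream, args...) corresponds to kernel<<<gridDim, blockDim, dynamicShmem, stream>>>(args...). A minimal sketch of the two forms, using a hypothetical copy kernel rather than the project's GEMM kernels:

// Hypothetical kernel, only to illustrate the launch-syntax mapping.
__global__ void copy_kernel(const float *in, float *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i];
}

void launch_copy(const float *in, float *out, int n, cudaStream_t stream) {
    dim3 grid((n + 255) / 256), block(256);
    // CUDA form, which this revert restores:
    copy_kernel<<<grid, block, 0, stream>>>(in, out, n);
    // Equivalent HIP form used by the reverted commit:
    //   hipLaunchKernelGGL(copy_kernel, grid, block, 0, stream, in, out, n);
}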
src/kernels/zgemm/gemm_w4a4_test.hip → src/kernels/zgemm/gemm_w4a4_test.cu

-#include "hip/hip_runtime.h"
 #include "zgemm.h"
 #include "gemm_w4a4.cuh"
 #include "epilogues.cuh"
...
@@ -22,11 +21,11 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
     auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-    checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
     dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
-    hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), kernel::SHMEM_SIZE, getCurrentHIPStreamMasqueradingAsCUDA(),
+    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
         typename kernel::Arguments{.input = input.data_ptr<GEMM::half_t>(),
                                    .output = output.data_ptr<GEMM::half_t>(),
                                    .M = M,
...
@@ -39,7 +38,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
                                    .rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
                                    .epsilon = 1e-6,
         }});
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
 }

 void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
...
@@ -60,11 +59,11 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
     auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-    checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
     dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
-    hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), kernel::SHMEM_SIZE, getCurrentHIPStreamMasqueradingAsCUDA(),
+    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
         typename kernel::Arguments{
             .input = input.data_ptr<GEMM::half_t>(),
             .output = output.data_ptr<GEMM::half_t>(),
...
@@ -84,7 +83,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
             .strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
         }});
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
 }
 }; // namespace nunchaku::kernels
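Both test launchers, like the launchers in gemm_w4a4_launch_impl.cuh, raise cudaFuncAttributeMaxDynamicSharedMemorySize before launching, because a kernel that requests more than the default 48 KB of dynamic shared memory must opt in explicitly. A minimal sketch of that pattern, with a hypothetical kernel and size rather than the project's values:

// Hypothetical kernel that fills a dynamically sized shared buffer.
__global__ void big_shmem_kernel(float *data) {
    extern __shared__ float buf[];
    buf[threadIdx.x] = data[threadIdx.x];
    __syncthreads();
    data[threadIdx.x] = buf[threadIdx.x] * 2.0f;
}

void launch_big_shmem(float *data, cudaStream_t stream) {
    int shmem = 96 * 1024;  // assumed size, above the 48 KB default
    cudaFuncSetAttribute(big_shmem_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, shmem);
    big_shmem_kernel<<<1, 256, shmem, stream>>>(data);
}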
src/kernels/zgemm/gemm_w8a8.hip → src/kernels/zgemm/gemm_w8a8.cu

-#include "hip/hip_runtime.h"
 #include "zgemm.h"
 #include "gemm_w8a8.cuh"
...
@@ -27,14 +26,14 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
     auto func = invoke_kernel<kernel,
                               const GEMM::half_t *,
                               GEMM::packed_act_t *,
                               GEMM::packed_ascale_t *,
                               int,
                               bool>;
-    checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, 92160));
+    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
-    hipLaunchKernelGGL((func), dim3(grid), dim3(block), kernel::smemSize(M, K), 0, input.data_ptr<GEMM::half_t>(),
+    func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(),
         output.data_ptr<GEMM::packed_act_t>(),
         oscales.data_ptr<GEMM::packed_ascale_t>(),
         K,
         false);
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
     };
     if (fuse_glu) {
...
@@ -75,8 +74,8 @@ void gemm_w8a8(Tensor act, // [M, K]
         std::swap(grid.x, grid.y);
     }
-    hipLaunchKernelGGL((invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), 0, 0, act.data_ptr<GEMM::packed_act_t>(),
+    invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(act.data_ptr<GEMM::packed_act_t>(),
         wgt.data_ptr<GEMM::packed_wgt_t>(),
         ascales.data_ptr<GEMM::packed_ascale_t>(),
         wscales.data_ptr<GEMM::packed_wscale_t>(),
...
@@ -87,7 +86,7 @@ void gemm_w8a8(Tensor act, // [M, K]
         args,
         swapBlockMN,
         false);
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
     };
     auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
...
@@ -148,7 +147,7 @@ void gemm_w8a8_fuse_litela(
     epilogueArgs.out_q = out_q.data_ptr<GEMM::half_t>();
     epilogueArgs.out_vk = out_vk.data_ptr<float>();
-    checkCUDA(hipMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
+    checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
     auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
                               const GEMM::packed_act_t *,
...
@@ -161,7 +160,7 @@ void gemm_w8a8_fuse_litela(
                               bool,
                               bool>;
-    checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));
+    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));
     dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
...
@@ -170,7 +169,7 @@ void gemm_w8a8_fuse_litela(
         std::swap(grid.x, grid.y);
     }
-    hipLaunchKernelGGL((func), dim3(grid), dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS), Epilogue::SHMEM_SIZE, 0,
+    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
         act.data_ptr<GEMM::packed_act_t>(),
         wgt.data_ptr<GEMM::packed_wgt_t>(),
         ascales.data_ptr<GEMM::packed_ascale_t>(),
...
@@ -180,14 +179,14 @@ void gemm_w8a8_fuse_litela(
         swapBlockMN,
         false);
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
-    hipLaunchKernelGGL((invoke_kernel<Epilogue::vk_mul_q_kernel>), dim3(dim3(batch_m / 128, num_heads, batch_size)), dim3(128), 0, 0,
+    invoke_kernel<Epilogue::vk_mul_q_kernel><<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
         out_q.data_ptr<GEMM::half_t>(),
         out_vk.data_ptr<float>(),
         1e-6f);
-    checkCUDA(hipGetLastError());
+    checkCUDA(cudaGetLastError());
 }
 #endif
...
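Every restored launch is immediately followed by checkCUDA(cudaGetLastError()), the standard way to surface launch-configuration failures from an asynchronous <<<...>>> call. A minimal sketch of that error-checking pattern, using a simple handler in place of nunchaku's checkCUDA:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical stand-in for the project's checkCUDA helper.
inline void check_cuda(cudaError_t err, const char *what) {
    if (err != cudaSuccess) {
        std::fprintf(stderr, "%s: %s\n", what, cudaGetErrorString(err));
        std::exit(EXIT_FAILURE);
    }
}

__global__ void noop_kernel() {}

void launch_and_check(cudaStream_t stream) {
    noop_kernel<<<1, 32, 0, stream>>>();
    // Catches invalid launch configurations right away...
    check_cuda(cudaGetLastError(), "noop_kernel launch");
    // ...while execution errors only surface after synchronization.
    check_cuda(cudaStreamSynchronize(stream), "noop_kernel execution");
}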
src/kernels/zgemm/gemm_w8a8.cuh

-#include "hip/hip_runtime.h"
 #pragma once
 #include "gemm_base.cuh"
...
@@ -439,7 +438,7 @@ public:
     // out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
     template<typename Epilogue>
     struct gemm_w8a8_kernel {
-        static constexpr int MIN_ARCH = std::is_same_v<half_t, __hip_bfloat16> ? 800 : 750;
+        static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
         __device__ void operator()(const packed_act_t *act,
                                    const packed_wgt_t *wgt,
                                    const packed_ascale_t *ascales,
...
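The restored MIN_ARCH line picks the minimum compute capability from the element type at compile time: a __nv_bfloat16 instantiation needs sm_80, while fp16 can run from sm_75. A small sketch of that type-driven constant, with the thresholds taken from the hunk above:

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <type_traits>

// Sketch of the MIN_ARCH selection; "half_t" is whichever 16-bit type a
// configuration uses, as in gemm_w8a8_kernel above.
template<typename half_t>
struct arch_floor {
    static constexpr int MIN_ARCH =
        std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
};

static_assert(arch_floor<__nv_bfloat16>::MIN_ARCH == 800);
static_assert(arch_floor<__half>::MIN_ARCH == 750);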
src/kernels/zgemm/mma.cuh
...
@@ -35,7 +35,7 @@ using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
 __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
     uint2 d;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                  "{%0, %1},"
                  "{%2, %3, %4, %5},"
...
@@ -66,7 +66,7 @@ __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2
 template<bool is_bf16>
 __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
     uint4 d;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
                  "{%0, %1, %2, %3},"
                  "{%4, %5, %6, %7},"
...
@@ -110,7 +110,7 @@ __device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b,
     uint4 d;
     static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
                  "{%0, %1, %2, %3},"
                  "{%4, %5, %6, %7},"
...
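The guards restored in mma.cuh are the usual way to gate inline PTX by architecture: __CUDA_ARCH__ is defined only during the device compilation pass, so the #if both hides the asm from host compilation and restricts it to sm_80 and newer. A minimal sketch of the guard pattern with a trivial body, not the project's actual mma wrappers:

// Hypothetical arch-gated device helper, for illustration only.
__device__ __forceinline__ static float gated_fma(float a, float b, float c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // Branch compiled only for sm_80+, where the real file emits mma.sync PTX.
    return __fmaf_rn(a, b, c);
#else
    // Fallback for older architectures and for the host compilation pass.
    return a * b + c;
#endif
}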
src/kernels/zgemm/mma_earlycuda.cuh
...
@@ -36,7 +36,7 @@ using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
 __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
     uint2 d;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                  "{%0, %1},"
                  "{%2, %3, %4, %5},"
...
@@ -67,7 +67,7 @@ __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2
 template<bool is_bf16>
 __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 template<>
 __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
     uint4 d;
...
@@ -85,7 +85,7 @@ __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2
 template<>
 __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
     uint4 d;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                  "{%0, %1, %2, %3},"
                  "{%4, %5, %6, %7},"
...
@@ -121,7 +121,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
     uint4 d;
     static constexpr int K = 64;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
                  "{%0, %1, %2, %3},"
...
@@ -175,7 +175,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
     uint4 d;
     static constexpr int K = 64;
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
                  "{%0, %1, %2, %3},"
...
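mma_earlycuda.cuh keeps the bf16/fp16 wrapper as a deleted primary template and supplies explicit specializations for both is_bf16 values, so any unsupported instantiation is rejected at compile time. A minimal sketch of that dispatch pattern, detached from the actual mma code:

// Hypothetical illustration of the deleted-primary-template pattern.
template<bool is_bf16>
__device__ __forceinline__ static float2 scale2(float2 v) = delete;

template<>
__device__ __forceinline__ float2 scale2<true>(float2 v) {
    return make_float2(v.x * 2.0f, v.y * 2.0f);   // stand-in for the bf16 path
}

template<>
__device__ __forceinline__ float2 scale2<false>(float2 v) {
    return make_float2(v.x * 0.5f, v.y * 0.5f);   // stand-in for the fp16 path
}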
src/pytorch_compat.h

-#include "hip/hip_runtime.h"
 #pragma once
 #include "common.h"
...
@@ -10,7 +9,7 @@ inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
 }
 template<typename T>
-inline void C10_HIP_CHECK(T ret) {
+inline void C10_CUDA_CHECK(T ret) {
     return checkCUDA(ret);
 }
...
@@ -35,16 +34,16 @@ namespace cuda {
 using ::getCurrentDeviceProperties;
 struct StreamWrapper {
-    hipStream_t st;
-    hipStream_t stream() const {
+    cudaStream_t st;
+    cudaStream_t stream() const {
         return st;
     }
 };
-inline StreamWrapper getCurrentHIPStreamMasqueradingAsCUDA() {
-    return StreamWrapper(::getCurrentHIPStreamMasqueradingAsCUDA());
+inline StreamWrapper getCurrentCUDAStream() {
+    return StreamWrapper(::getCurrentCUDAStream());
 }
-struct HIPGuardMasqueradingAsCUDA {
+struct CUDAGuard {
     int dev;
 };
...
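pytorch_compat.h mimics the small slice of PyTorch's c10/ATen surface that the kernels touch (C10_CUDA_CHECK, getCurrentCUDAStream, CUDAGuard) so the code can build without libtorch; the revert simply restores the CUDA spellings of those names. A minimal sketch of that shim idea, with the namespace and defaults assumed rather than copied from the file:

// Hypothetical standalone shim in the spirit of src/pytorch_compat.h.
#include <cstdlib>
#include <cuda_runtime.h>

inline void checkCUDA(cudaError_t err) {
    if (err != cudaSuccess) std::abort();   // the real helper reports the error
}

template<typename T>
inline void C10_CUDA_CHECK(T ret) {
    checkCUDA(ret);
}

namespace cuda {
struct StreamWrapper {
    cudaStream_t st;
    cudaStream_t stream() const { return st; }
};
// Assumed to hand back whatever stream the host side is currently using;
// the default (null) stream stands in here.
inline StreamWrapper getCurrentCUDAStream() { return StreamWrapper{nullptr}; }
struct CUDAGuard {
    int dev;
};
} // namespace cuda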