Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
9316940c
Commit
9316940c
authored
Nov 21, 2025
by
fengzch
Browse files
fix: compile misc_kernels.cu complete
parent
038b8469
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
5 deletions
+18
-5
src/kernels/utils.cuh
src/kernels/utils.cuh
+7
-1
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
+1
-1
src/kernels/zgemm/gemm_w8a8.cuh
src/kernels/zgemm/gemm_w8a8.cuh
+10
-3
No files found.
src/kernels/utils.cuh
View file @
9316940c
...
...
@@ -225,6 +225,11 @@ __device__ inline T_OUT cuda_cast(T_IN val) {
return
val
;
}
template
<
>
__device__
inline
__hip_bfloat16
cuda_cast
<
__hip_bfloat16
,
long
>
(
long
val
)
{
return
(
long
long
)
val
;
}
template
<
>
__device__
inline
float2
cuda_cast
<
float2
,
int2
>
(
int2
val
)
{
return
make_float2
(
val
.
x
,
val
.
y
);
...
...
@@ -268,7 +273,8 @@ __device__ inline int8_t cuda_cast<int8_t, half>(half val) {
};
fp16
=
val
;
asm
volatile
(
"cvt.rni.sat.s8.f16 %0, %1;"
:
"=h"
(
int16
)
:
"h"
(
int16_in
));
int16
=
int16_in
;
// asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return
int8
[
0
];
}
...
...
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
View file @
9316940c
...
...
@@ -101,7 +101,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
bool
>
;
if
(
shmem
>=
24
*
1024
)
{
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shmem
));
checkCUDA
(
cudaFuncSetAttribute
(
reinterpret_cast
<
const
void
*>
(
func
)
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shmem
));
}
assert
(
alpha
==
1.0
f
);
...
...
src/kernels/zgemm/gemm_w8a8.cuh
View file @
9316940c
...
...
@@ -204,7 +204,14 @@ public:
#pragma unroll
for
(
int
mask
=
32
/
2
;
mask
>
0
;
mask
/=
2
)
{
maxvalue2
=
__hmax2
(
maxvalue2
,
__shfl_xor
(
maxvalue2
,
mask
));
__half2
m
;
m
.
x
=
float
(
maxvalue2
.
x
);
m
.
y
=
float
(
maxvalue2
.
y
);
auto
temp
=
__shfl_xor
(
m
,
mask
);
__hip_bfloat162
n
;
n
.
x
=
float
(
temp
.
x
);
n
.
y
=
float
(
temp
.
y
);
maxvalue2
=
__hmax2
(
maxvalue2
,
n
);
}
return
__hmax
(
maxvalue2
.
x
,
maxvalue2
.
y
);
...
...
@@ -243,9 +250,9 @@ public:
const
int
bm
=
blockIdx
.
x
/
(
BLOCK_M
/
WARP_M
);
const
int
gemmWarpId
=
blockIdx
.
x
%
(
BLOCK_M
/
WARP_M
);
__shared__
align
as
(
128
)
half_t
oscale_shmem
[
WARP_M
];
__shared__
__attribute__
((
align
ed
(
128
)
))
half_t
oscale_shmem
[
WARP_M
];
// __shared__ alignas(128) half_t maxv_shmem[WARP_M];
__shared__
align
as
(
128
)
uint8_t
tmp_shmem
[
NUM_WARPS
][
512
];
__shared__
__attribute__
((
align
ed
(
128
)
))
uint8_t
tmp_shmem
[
NUM_WARPS
][
512
];
const
int
K2
=
fuse_glu
?
K
/
2
:
K
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment