Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
fcc551cb
Commit
fcc551cb
authored
Mar 26, 2025
by
sxtyzhangzk
Committed by
Zhekai Zhang
Apr 01, 2025
Browse files
[major] fix compilation error
WHO THE HELL INVENTED ADL?
parent
9c92fe81
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
37 additions
and
37 deletions
+37
-37
src/kernels/zgemm/attention.cuh
src/kernels/zgemm/attention.cuh
+6
-6
src/kernels/zgemm/gemm_base.cuh
src/kernels/zgemm/gemm_base.cuh
+14
-14
src/kernels/zgemm/gemm_utils.cuh
src/kernels/zgemm/gemm_utils.cuh
+3
-3
src/kernels/zgemm/gemm_w4a4.cuh
src/kernels/zgemm/gemm_w4a4.cuh
+14
-14
No files found.
src/kernels/zgemm/attention.cuh
View file @
fcc551cb
...
@@ -126,12 +126,12 @@ public:
...
@@ -126,12 +126,12 @@ public:
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
results
[
i
]
=
float22half2
<
half2_t
>
(
float2
(
input
.
data
[
i
*
2
],
input
.
data
[
i
*
2
+
1
]));
results
[
i
]
=
float22half2
<
half2_t
>
(
float2
(
input
.
data
[
i
*
2
],
input
.
data
[
i
*
2
+
1
]));
}
}
return
bit_cast
<
packed_fpsum_t
>
(
results
);
return
kernels
::
bit_cast
<
packed_fpsum_t
>
(
results
);
}
}
__device__
__forceinline__
__device__
__forceinline__
static
packed_f32psum_t
packed_fp16_to_fp32
(
packed_fpsum_t
input
)
{
static
packed_f32psum_t
packed_fp16_to_fp32
(
packed_fpsum_t
input
)
{
auto
arr
=
bit_cast
<
std
::
array
<
half2_t
,
4
>>
(
input
);
auto
arr
=
kernels
::
bit_cast
<
std
::
array
<
half2_t
,
4
>>
(
input
);
packed_f32psum_t
results
;
packed_f32psum_t
results
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
tmp
=
half22float2
(
arr
[
i
]);
float2
tmp
=
half22float2
(
arr
[
i
]);
...
@@ -214,10 +214,10 @@ public:
...
@@ -214,10 +214,10 @@ public:
__device__
__forceinline__
__device__
__forceinline__
static
packed_fpsum_t
fix_nan
(
packed_fpsum_t
input
)
{
static
packed_fpsum_t
fix_nan
(
packed_fpsum_t
input
)
{
input
.
x
=
bit_cast
<
int
>
(
fix_nan
(
bit_cast
<
half2_t
>
(
input
.
x
)));
input
.
x
=
kernels
::
bit_cast
<
int
>
(
fix_nan
(
kernels
::
bit_cast
<
half2_t
>
(
input
.
x
)));
input
.
y
=
bit_cast
<
int
>
(
fix_nan
(
bit_cast
<
half2_t
>
(
input
.
y
)));
input
.
y
=
kernels
::
bit_cast
<
int
>
(
fix_nan
(
kernels
::
bit_cast
<
half2_t
>
(
input
.
y
)));
input
.
z
=
bit_cast
<
int
>
(
fix_nan
(
bit_cast
<
half2_t
>
(
input
.
z
)));
input
.
z
=
kernels
::
bit_cast
<
int
>
(
fix_nan
(
kernels
::
bit_cast
<
half2_t
>
(
input
.
z
)));
input
.
w
=
bit_cast
<
int
>
(
fix_nan
(
bit_cast
<
half2_t
>
(
input
.
w
)));
input
.
w
=
kernels
::
bit_cast
<
int
>
(
fix_nan
(
kernels
::
bit_cast
<
half2_t
>
(
input
.
w
)));
return
input
;
return
input
;
}
}
...
...
src/kernels/zgemm/gemm_base.cuh
View file @
fcc551cb
...
@@ -206,21 +206,21 @@ public:
...
@@ -206,21 +206,21 @@ public:
static
constexpr
bool
is_bf16
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
;
static
constexpr
bool
is_bf16
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
;
uint4
out1
=
mma_m16n8k16_f32f16f16f32
<
is_bf16
>
(
uint4
out1
=
mma_m16n8k16_f32f16f16f32
<
is_bf16
>
(
bit_cast
<
uint4
>
(
a
),
kernels
::
bit_cast
<
uint4
>
(
a
),
bit_cast
<
uint2
>
(
std
::
array
<
half2_t
,
2
>
(
b
.
data
[
0
],
b
.
data
[
1
])),
kernels
::
bit_cast
<
uint2
>
(
std
::
array
<
half2_t
,
2
>
(
b
.
data
[
0
],
b
.
data
[
1
])),
bit_cast
<
uint4
>
(
float4
(
psum
.
data
[
0
],
psum
.
data
[
1
],
psum
.
data
[
2
],
psum
.
data
[
3
])));
kernels
::
bit_cast
<
uint4
>
(
float4
(
psum
.
data
[
0
],
psum
.
data
[
1
],
psum
.
data
[
2
],
psum
.
data
[
3
])));
uint4
out2
=
mma_m16n8k16_f32f16f16f32
<
is_bf16
>
(
uint4
out2
=
mma_m16n8k16_f32f16f16f32
<
is_bf16
>
(
bit_cast
<
uint4
>
(
a
),
kernels
::
bit_cast
<
uint4
>
(
a
),
bit_cast
<
uint2
>
(
std
::
array
<
half2_t
,
2
>
(
b
.
data
[
2
],
b
.
data
[
3
])),
kernels
::
bit_cast
<
uint2
>
(
std
::
array
<
half2_t
,
2
>
(
b
.
data
[
2
],
b
.
data
[
3
])),
bit_cast
<
uint4
>
(
float4
(
psum
.
data
[
4
],
psum
.
data
[
5
],
psum
.
data
[
6
],
psum
.
data
[
7
])));
kernels
::
bit_cast
<
uint4
>
(
float4
(
psum
.
data
[
4
],
psum
.
data
[
5
],
psum
.
data
[
6
],
psum
.
data
[
7
])));
psum
.
data
[
0
]
=
bit_cast
<
float
>
(
out1
.
x
);
psum
.
data
[
0
]
=
kernels
::
bit_cast
<
float
>
(
out1
.
x
);
psum
.
data
[
1
]
=
bit_cast
<
float
>
(
out1
.
y
);
psum
.
data
[
1
]
=
kernels
::
bit_cast
<
float
>
(
out1
.
y
);
psum
.
data
[
2
]
=
bit_cast
<
float
>
(
out1
.
z
);
psum
.
data
[
2
]
=
kernels
::
bit_cast
<
float
>
(
out1
.
z
);
psum
.
data
[
3
]
=
bit_cast
<
float
>
(
out1
.
w
);
psum
.
data
[
3
]
=
kernels
::
bit_cast
<
float
>
(
out1
.
w
);
psum
.
data
[
4
]
=
bit_cast
<
float
>
(
out2
.
x
);
psum
.
data
[
4
]
=
kernels
::
bit_cast
<
float
>
(
out2
.
x
);
psum
.
data
[
5
]
=
bit_cast
<
float
>
(
out2
.
y
);
psum
.
data
[
5
]
=
kernels
::
bit_cast
<
float
>
(
out2
.
y
);
psum
.
data
[
6
]
=
bit_cast
<
float
>
(
out2
.
z
);
psum
.
data
[
6
]
=
kernels
::
bit_cast
<
float
>
(
out2
.
z
);
psum
.
data
[
7
]
=
bit_cast
<
float
>
(
out2
.
w
);
psum
.
data
[
7
]
=
kernels
::
bit_cast
<
float
>
(
out2
.
w
);
return
psum
;
return
psum
;
}
}
...
...
src/kernels/zgemm/gemm_utils.cuh
View file @
fcc551cb
...
@@ -573,7 +573,7 @@ static half2 int2half2_fast_8192(int x, int y) {
...
@@ -573,7 +573,7 @@ static half2 int2half2_fast_8192(int x, int y) {
ival
=
ival
>>
4
;
ival
=
ival
>>
4
;
// (val & 0x03FF03FF) ^ 0x76007600
// (val & 0x03FF03FF) ^ 0x76007600
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;"
:
"=r"
(
hval
)
:
"r"
(
ival
),
"n"
(
0x03FF03FF
),
"n"
(
0x76007600
),
"n"
((
0xF0
&
0xCC
)
^
0xAA
));
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;"
:
"=r"
(
hval
)
:
"r"
(
ival
),
"n"
(
0x03FF03FF
),
"n"
(
0x76007600
),
"n"
((
0xF0
&
0xCC
)
^
0xAA
));
return
__hadd2
(
bit_cast
<
half2
>
(
hval
),
half2
(
-
24576.0
f
,
-
24576.0
f
));
return
__hadd2
(
kernels
::
bit_cast
<
half2
>
(
hval
),
half2
(
-
24576.0
f
,
-
24576.0
f
));
}
}
// val in [-4096, 4095], steps of 8, round to nearest
// val in [-4096, 4095], steps of 8, round to nearest
__device__
__forceinline__
__device__
__forceinline__
...
@@ -590,7 +590,7 @@ static half2 int2half2_fast_4096_rn(int x, int y) {
...
@@ -590,7 +590,7 @@ static half2 int2half2_fast_4096_rn(int x, int y) {
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;"
:
"=r"
(
ival
)
:
"r"
(
x
),
"r"
(
y
),
"n"
(
0x7632
));
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;"
:
"=r"
(
ival
)
:
"r"
(
x
),
"r"
(
y
),
"n"
(
0x7632
));
// (val & 0x03FF03FF) ^ 0x72007200
// (val & 0x03FF03FF) ^ 0x72007200
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;"
:
"=r"
(
hval
)
:
"r"
(
ival
),
"n"
(
0x03FF03FF
),
"n"
(
0x72007200
),
"n"
((
0xF0
&
0xCC
)
^
0xAA
));
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;"
:
"=r"
(
hval
)
:
"r"
(
ival
),
"n"
(
0x03FF03FF
),
"n"
(
0x72007200
),
"n"
((
0xF0
&
0xCC
)
^
0xAA
));
return
__hadd2
(
bit_cast
<
half2
>
(
hval
),
half2
(
-
12288.0
f
,
-
12288.0
f
));
return
__hadd2
(
kernels
::
bit_cast
<
half2
>
(
hval
),
half2
(
-
12288.0
f
,
-
12288.0
f
));
}
}
// val in [-512, 511]
// val in [-512, 511]
__device__
__forceinline__
__device__
__forceinline__
...
@@ -602,7 +602,7 @@ static half2 int2half2_fast_512(int x, int y) {
...
@@ -602,7 +602,7 @@ static half2 int2half2_fast_512(int x, int y) {
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;"
:
"=r"
(
ival
)
:
"r"
(
x
),
"r"
(
y
),
"n"
(
0x5410
));
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;"
:
"=r"
(
ival
)
:
"r"
(
x
),
"r"
(
y
),
"n"
(
0x5410
));
// (val & 0x03FF03FF) ^ 0x66006600
// (val & 0x03FF03FF) ^ 0x66006600
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;"
:
"=r"
(
hval
)
:
"r"
(
ival
),
"n"
(
0x03FF03FF
),
"n"
(
0x66006600
),
"n"
((
0xF0
&
0xCC
)
^
0xAA
));
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;"
:
"=r"
(
hval
)
:
"r"
(
ival
),
"n"
(
0x03FF03FF
),
"n"
(
0x66006600
),
"n"
((
0xF0
&
0xCC
)
^
0xAA
));
return
__hadd2
(
bit_cast
<
half2
>
(
hval
),
half2
(
-
1536.0
f
,
-
1536.0
f
));
return
__hadd2
(
kernels
::
bit_cast
<
half2
>
(
hval
),
half2
(
-
1536.0
f
,
-
1536.0
f
));
}
}
};
// namespace nunchaku::kernels
};
// namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4.cuh
View file @
fcc551cb
...
@@ -1674,7 +1674,7 @@ public:
...
@@ -1674,7 +1674,7 @@ public:
const
int
col
=
n
*
INSN_N
+
laneId
/
16
*
8
;
// lane 0-15: n*16+0, lane 16-31: n*16+8
const
int
col
=
n
*
INSN_N
+
laneId
/
16
*
8
;
// lane 0-15: n*16+0, lane 16-31: n*16+8
uint4
tmp
;
uint4
tmp
;
ldmatrix
(
shmem
+
col
,
tmp
);
ldmatrix
(
shmem
+
col
,
tmp
);
return
bit_cast
<
packed_fpsum_t
>
(
tmp
);
return
kernels
::
bit_cast
<
packed_fpsum_t
>
(
tmp
);
}
}
__device__
__forceinline__
__device__
__forceinline__
...
@@ -1813,30 +1813,30 @@ public:
...
@@ -1813,30 +1813,30 @@ public:
__device__
__forceinline__
__device__
__forceinline__
static
packed_qkv_t
pack_q
(
packed_fpsum_t
input
)
{
static
packed_qkv_t
pack_q
(
packed_fpsum_t
input
)
{
packed_qkv_t
output
;
packed_qkv_t
output
;
output
.
x
=
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
0
]));
output
.
x
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
0
]));
output
.
y
=
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
1
]));
output
.
y
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
1
]));
output
.
z
=
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
2
]));
output
.
z
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
2
]));
output
.
w
=
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
3
]));
output
.
w
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
3
]));
return
output
;
return
output
;
}
}
__device__
__forceinline__
__device__
__forceinline__
static
packed_qkv_t
pack_k
(
packed_fpsum_t
input
)
{
static
packed_qkv_t
pack_k
(
packed_fpsum_t
input
)
{
packed_qkv_t
output
;
packed_qkv_t
output
;
output
.
x
=
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
0
]));
output
.
x
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
0
]));
output
.
y
=
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
2
]));
output
.
y
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
2
]));
output
.
z
=
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
1
]));
output
.
z
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
1
]));
output
.
w
=
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
3
]));
output
.
w
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
input
.
data
[
3
]));
return
output
;
return
output
;
}
}
__device__
__forceinline__
__device__
__forceinline__
static
packed_qkv_t
pack_v
(
packed_fpsum_t
input
)
{
static
packed_qkv_t
pack_v
(
packed_fpsum_t
input
)
{
packed_qkv_t
output
;
packed_qkv_t
output
;
output
.
x
=
bit_cast
<
int
>
(
convert_half2
(
movmatrix
(
input
.
data
[
0
])));
output
.
x
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
movmatrix
(
input
.
data
[
0
])));
output
.
y
=
bit_cast
<
int
>
(
convert_half2
(
movmatrix
(
input
.
data
[
1
])));
output
.
y
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
movmatrix
(
input
.
data
[
1
])));
output
.
z
=
bit_cast
<
int
>
(
convert_half2
(
movmatrix
(
input
.
data
[
2
])));
output
.
z
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
movmatrix
(
input
.
data
[
2
])));
output
.
w
=
bit_cast
<
int
>
(
convert_half2
(
movmatrix
(
input
.
data
[
3
])));
output
.
w
=
kernels
::
bit_cast
<
int
>
(
convert_half2
(
movmatrix
(
input
.
data
[
3
])));
return
output
;
return
output
;
}
}
...
@@ -1867,7 +1867,7 @@ public:
...
@@ -1867,7 +1867,7 @@ public:
unrolled_loop
<
WARP_M_TILES
>
([
&
]
<
int
m
>
()
ALWAYSINLINE
{
unrolled_loop
<
WARP_M_TILES
>
([
&
]
<
int
m
>
()
ALWAYSINLINE
{
unrolled_loop
<
WARP_N_TILES
>
([
&
]
<
int
n
>
()
ALWAYSINLINE
{
unrolled_loop
<
WARP_N_TILES
>
([
&
]
<
int
n
>
()
ALWAYSINLINE
{
packed_qkv_t
pack
=
funcPack
(
fpsum
[
m
*
WARP_N_TILES
+
n
]);
packed_qkv_t
pack
=
funcPack
(
fpsum
[
m
*
WARP_N_TILES
+
n
]);
mask
(
pack
,
bit_cast
<
uint32_t
>
(
maskVal
),
m
,
maxRows
-
warpId
*
WARP_M
);
mask
(
pack
,
kernels
::
bit_cast
<
uint32_t
>
(
maskVal
),
m
,
maxRows
-
warpId
*
WARP_M
);
store
(
&
ptrlane
[(
m
*
WARP_N_TILES
+
n
)
*
WARP_SIZE
],
pack
);
store
(
&
ptrlane
[(
m
*
WARP_N_TILES
+
n
)
*
WARP_SIZE
],
pack
);
});
});
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment