Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
54241df6
Commit
54241df6
authored
Nov 21, 2025
by
fengzch
Browse files
fix: compile gemm_w8a8.cu complete
parent
2cb9a2c7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
8 deletions
+8
-8
src/kernels/zgemm/gemm_utils.cuh
src/kernels/zgemm/gemm_utils.cuh
+7
-7
src/kernels/zgemm/gemm_w8a8.cu
src/kernels/zgemm/gemm_w8a8.cu
+1
-1
No files found.
src/kernels/zgemm/gemm_utils.cuh
View file @
54241df6
...
@@ -21,14 +21,14 @@ __device__ __forceinline__ static T load(const T *addr) {
...
@@ -21,14 +21,14 @@ __device__ __forceinline__ static T load(const T *addr) {
uint2
data
;
uint2
data
;
asm
volatile
(
"ld.shared.v2.b32 {%0, %1}, [%2];"
asm
volatile
(
"ld.shared.v2.b32 {%0, %1}, [%2];"
:
"=r"
(
data
.
x
),
"=r"
(
data
.
y
)
:
"=r"
(
data
.
x
),
"=r"
(
data
.
y
)
:
"l"
(
__cvta_generic_to_shared
(
addr
)));
:
"l"
((
addr
)));
return
*
reinterpret_cast
<
T
*>
(
&
data
);
return
*
reinterpret_cast
<
T
*>
(
&
data
);
}
}
if
constexpr
(
sizeof
(
T
)
==
16
)
{
if
constexpr
(
sizeof
(
T
)
==
16
)
{
uint4
data
;
uint4
data
;
asm
volatile
(
"ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
asm
volatile
(
"ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
:
"=r"
(
data
.
x
),
"=r"
(
data
.
y
),
"=r"
(
data
.
z
),
"=r"
(
data
.
w
)
:
"=r"
(
data
.
x
),
"=r"
(
data
.
y
),
"=r"
(
data
.
z
),
"=r"
(
data
.
w
)
:
"l"
(
__cvta_generic_to_shared
(
addr
)));
:
"l"
((
addr
)));
return
*
reinterpret_cast
<
T
*>
(
&
data
);
return
*
reinterpret_cast
<
T
*>
(
&
data
);
}
}
return
*
addr
;
return
*
addr
;
...
@@ -89,12 +89,12 @@ __device__ __forceinline__ static void store(T *addr, T val) {
...
@@ -89,12 +89,12 @@ __device__ __forceinline__ static void store(T *addr, T val) {
if
constexpr
(
sizeof
(
T
)
==
8
)
{
if
constexpr
(
sizeof
(
T
)
==
8
)
{
uint2
data
=
*
reinterpret_cast
<
uint2
*>
(
&
val
);
uint2
data
=
*
reinterpret_cast
<
uint2
*>
(
&
val
);
asm
volatile
(
asm
volatile
(
"st.shared.v2.b32 [%0], {%1, %2};"
::
"l"
(
__cvta_generic_to_shared
(
addr
)),
"r"
(
data
.
x
),
"r"
(
data
.
y
));
"st.shared.v2.b32 [%0], {%1, %2};"
::
"l"
((
addr
)),
"r"
(
data
.
x
),
"r"
(
data
.
y
));
return
;
return
;
}
}
if
constexpr
(
sizeof
(
T
)
==
16
)
{
if
constexpr
(
sizeof
(
T
)
==
16
)
{
uint4
data
=
*
reinterpret_cast
<
uint4
*>
(
&
val
);
uint4
data
=
*
reinterpret_cast
<
uint4
*>
(
&
val
);
asm
volatile
(
"st.shared.v4.b32 [%0], {%1, %2, %3, %4};"
::
"l"
(
__cvta_generic_to_shared
(
addr
)),
asm
volatile
(
"st.shared.v4.b32 [%0], {%1, %2, %3, %4};"
::
"l"
((
addr
)),
"r"
(
data
.
x
),
"r"
(
data
.
x
),
"r"
(
data
.
y
),
"r"
(
data
.
y
),
"r"
(
data
.
z
),
"r"
(
data
.
z
),
...
@@ -192,9 +192,9 @@ __device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
...
@@ -192,9 +192,9 @@ __device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
}
}
__device__
__forceinline__
static
void
ldmatrix
(
const
void
*
ptr
,
uint4
&
out
)
{
__device__
__forceinline__
static
void
ldmatrix
(
const
void
*
ptr
,
uint4
&
out
)
{
//
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
asm
volatile
(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];
\n
"
//
: "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
:
"=r"
(
out
.
x
),
"=r"
(
out
.
y
),
"=r"
(
out
.
z
),
"=r"
(
out
.
w
)
//
: "l"(
__cvta_generic_to_shared
(ptr))); // limengmeng
:
"l"
((
ptr
)));
// limengmeng
}
}
template
<
typename
T
>
template
<
typename
T
>
...
...
src/kernels/zgemm/gemm_w8a8.cu
View file @
54241df6
...
@@ -26,7 +26,7 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
...
@@ -26,7 +26,7 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
auto
func
=
auto
func
=
invoke_kernel
<
kernel
,
const
GEMM
::
half_t
*
,
GEMM
::
packed_act_t
*
,
GEMM
::
packed_ascale_t
*
,
int
,
bool
>
;
invoke_kernel
<
kernel
,
const
GEMM
::
half_t
*
,
GEMM
::
packed_act_t
*
,
GEMM
::
packed_ascale_t
*
,
int
,
bool
>
;
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
92160
));
checkCUDA
(
cudaFuncSetAttribute
(
reinterpret_cast
<
const
void
*>
(
func
)
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
92160
));
func
<<<
grid
,
block
,
kernel
::
smemSize
(
M
,
K
)
>>>
(
input
.
data_ptr
<
GEMM
::
half_t
>
(),
func
<<<
grid
,
block
,
kernel
::
smemSize
(
M
,
K
)
>>>
(
input
.
data_ptr
<
GEMM
::
half_t
>
(),
output
.
data_ptr
<
GEMM
::
packed_act_t
>
(),
output
.
data_ptr
<
GEMM
::
packed_act_t
>
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment