fengzch-das / nunchaku · Commit 60a0d6d7
Authored Mar 27, 2025 by sxtyzhangzk; committed by Zhekai Zhang, Apr 01, 2025
Parent: fcc551cb

Fix compile error on CUDA 12.4-

Showing 4 changed files with 463 additions and 177 deletions:

  src/kernels/zgemm/gemm_base.cuh        +1    -0
  src/kernels/zgemm/gemm_utils.cuh       +0    -177
  src/kernels/zgemm/mma.cuh              +189  -0
  src/kernels/zgemm/mma_earlycuda.cuh    +273  -0
src/kernels/zgemm/gemm_base.cuh

@@ -5,6 +5,7 @@
 #include "../utils.cuh"
 #include "../dispatch_utils.h"
 #include "gemm_utils.cuh"
+#include "mma_earlycuda.cuh"
 #pragma nv_diag_suppress 177
 ...
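For context, this one-line change makes gemm_base.cuh use the pre-CUDA-12.5 wrappers unconditionally. A hypothetical alternative (not what this commit does) would be to gate the choice on the standard nvcc version macros, keeping the "C"-constraint header for newer toolkits:

// Hypothetical sketch only: select the MMA header by nvcc version.
// __CUDACC_VER_MAJOR__ / __CUDACC_VER_MINOR__ are the standard nvcc macros.
#if defined(__CUDACC_VER_MAJOR__) && \
    (__CUDACC_VER_MAJOR__ > 12 || \
     (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5))
#include "mma.cuh"            // "C" string constraints available (CUDA 12.5+)
#else
#include "mma_earlycuda.cuh"  // explicit specializations for CUDA 12.4 and older
#endif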
src/kernels/zgemm/gemm_utils.cuh

@@ -195,183 +195,6 @@ static T movmatrix(T x) {
     return x;
 }

The 177 lines deleted by this hunk are the mma_helper tag namespace and the three inline-PTX MMA wrappers (mma_m16n8k16_f16f16f16f16, mma_m16n8k16_f32f16f16f32<is_bf16>, and mma_m16n8kx_s32common<AType, BType>). They move out of this header essentially unchanged, apart from gaining the nunchaku::kernels namespace, into the new mma.cuh reproduced in full below. The context following the deletion is untouched:

 // x in low bit, y in high bit
 template<int bitwidth, bool use_unsigned>
 __device__ __forceinline__
 ...
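As a usage illustration (not part of the diff), these wrappers keep the PTX mma.sync fragment layout: each thread of a warp supplies its slice of A, B, and the accumulator as packed 32-bit registers. A minimal smoke-test kernel, assuming the header defining the wrappers is included; the zero fragments merely exercise the code path:

// Hypothetical smoke test: one warp issues a single m16n8k16 f16 MMA.
// Fragment layout follows the PTX ISA; zero inputs produce zero outputs.
using namespace nunchaku::kernels;

__global__ void mma_smoke_test(uint2 *out) {
    uint4 a = {0, 0, 0, 0};  // this thread's slice of the 16x16 f16 A tile
    uint2 b = {0, 0};        // this thread's slice of the 16x8  f16 B tile
    uint2 c = {0, 0};        // this thread's slice of the 16x8  f16 accumulator
    uint2 d = mma_m16n8k16_f16f16f16f16(a, b, c);  // d = A*B + C, fragment-wise
    out[threadIdx.x] = d;
}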
src/kernels/zgemm/mma.cuh (new file, mode 100644)

#pragma once

#include <cstdint>
#include "common.h"

// only supports cuda 12.5+

namespace nunchaku::kernels {

namespace mma_helper {
    struct f32  { static constexpr const char value[] = "f32";  };
    struct f16  { static constexpr const char value[] = "f16";  };
    struct bf16 { static constexpr const char value[] = "bf16"; };
    struct s32  { static constexpr const char value[] = "s32";  };
    struct s4   { static constexpr const char value[] = "s4";   };
    struct u4   { static constexpr const char value[] = "u4";   };

    template<bool is_bf16>
    using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;

    template<bool is_unsigned>
    using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
};

__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
    uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, %9};\n"
        : "=r"(d.x), "=r"(d.y)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y));
#else
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1;"
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{tmp0, tmp1},"
        "{%2, %3},"
        "{%6},"
        "{%8, %9};\n"
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%4, %5},"
        "{%7},"
        "{tmp0, tmp1};"
        "}\n"
        : "=r"(d.x), "=r"(d.y)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y));
#endif
    return d;
}

template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
    uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
          "C"(mma_helper::f16bf16<is_bf16>::value));
#else
    static_assert(!is_bf16);
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{tmp0, tmp1, tmp2, tmp3},"
        "{%4, %5},"
        "{%8},"
        "{%10, %11, %12, %13};\n"
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%6, %7},"
        "{%9},"
        "{tmp0, tmp1, tmp2, tmp3};"
        "}\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#endif
    return d;
}

template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    static constexpr int K =
        (std::is_same_v<AType, mma_helper::s4> ||
         std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
          "n"(K), "C"(AType::value), "C"(BType::value));
#else
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
        "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
        "{tmp0, tmp1},"
        "{%4},"
        "{%8},"
        "{%10, %11};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
        "{tmp2, tmp3},"
        "{%5},"
        "{%8},"
        "{%12, %13};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
        "{%0, %1},"
        "{%6},"
        "{%9},"
        "{tmp0, tmp1};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
        "{%2, %3},"
        "{%7},"
        "{%9},"
        "{tmp2, tmp3};\n"
        "}\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
          "n"(K / 2), "C"(AType::value), "C"(BType::value));
#endif
    return d;
}

};  // namespace nunchaku::kernels
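The "%14", "%15", and "%16" placeholders above are where this header depends on the "C" inline-assembly constraint, accepted by nvcc starting with CUDA 12.5: it splices a constexpr character array into the PTX template as literal text, so a single asm string covers f16/bf16 (and s4/u4) without duplicating the instruction. A minimal sketch of the mechanism; the f32tag struct and add_via_ptx function are illustrative, not from the repository:

// Requires nvcc 12.5+. The constexpr string bound with "C" is pasted into
// the template, so "%3" becomes the literal text "f32" before ptxas runs.
struct f32tag { static constexpr const char value[] = "f32"; };

__device__ float add_via_ptx(float a, float b) {
    float d;
    asm("add.%3 %0, %1, %2;" : "=f"(d) : "f"(a), "f"(b), "C"(f32tag::value));
    return d;
}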
src/kernels/zgemm/mma_earlycuda.cuh (new file, mode 100644)

#pragma once

#include <cstdint>
#include "common.h"

// cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now

namespace nunchaku::kernels {

namespace mma_helper {
    struct f32  { static constexpr const char value[] = "f32";  };
    struct f16  { static constexpr const char value[] = "f16";  };
    struct bf16 { static constexpr const char value[] = "bf16"; };
    struct s32  { static constexpr const char value[] = "s32";  };
    struct s4   { static constexpr const char value[] = "s4";   };
    struct u4   { static constexpr const char value[] = "u4";   };

    template<bool is_bf16>
    using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;

    template<bool is_unsigned>
    using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
};

__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
    uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, %9};\n"
        : "=r"(d.x), "=r"(d.y)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y));
#else
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1;"
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{tmp0, tmp1},"
        "{%2, %3},"
        "{%6},"
        "{%8, %9};\n"
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%4, %5},"
        "{%7},"
        "{tmp0, tmp1};"
        "}\n"
        : "=r"(d.x), "=r"(d.y)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y));
#endif
    return d;
}

template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
    return d;
}
#endif

template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#else
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{tmp0, tmp1, tmp2, tmp3},"
        "{%4, %5},"
        "{%8},"
        "{%10, %11, %12, %13};\n"
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%6, %7},"
        "{%9},"
        "{tmp0, tmp1, tmp2, tmp3};"
        "}\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#endif
    return d;
}

template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) = delete;

template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
          "n"(K));
#else
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
        "{tmp0, tmp1},"
        "{%4},"
        "{%8},"
        "{%10, %11};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
        "{tmp2, tmp3},"
        "{%5},"
        "{%8},"
        "{%12, %13};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
        "{%0, %1},"
        "{%6},"
        "{%9},"
        "{tmp0, tmp1};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
        "{%2, %3},"
        "{%7},"
        "{%9},"
        "{tmp2, tmp3};\n"
        "}\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
          "n"(K / 2));
#endif
    return d;
}

template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
          "n"(K));
#else
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
        "{tmp0, tmp1},"
        "{%4},"
        "{%8},"
        "{%10, %11};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
        "{tmp2, tmp3},"
        "{%5},"
        "{%8},"
        "{%12, %13};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
        "{%0, %1},"
        "{%6},"
        "{%9},"
        "{tmp0, tmp1};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
        "{%2, %3},"
        "{%7},"
        "{%9},"
        "{tmp2, tmp3};\n"
        "}\n"
        : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
          "n"(K / 2));
#endif
    return d;
}

};  // namespace nunchaku::kernels
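The "= delete" primary templates make any AType/BType pair without an explicit specialization a compile-time error rather than silently emitting wrong PTX; this header currently specializes only <s4, s4> and <u4, s4>. A hypothetical call site, illustrative and not from the diff:

// Illustrative only: one s4 x s4 -> s32 tile update (m16n8k64).
// a, b, c hold this thread's fragments in the PTX mma.sync layout.
using namespace nunchaku::kernels;

__device__ uint4 accumulate_int4_tile(uint4 a, uint2 b, uint4 c) {
    // Any tag pair other than <s4, s4> or <u4, s4> hits the deleted
    // primary template and is rejected at compile time.
    return mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(a, b, c);
}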