OpenDAS / AutoAWQ · Commits

Commit 7c59407c, authored Oct 03, 2023 by Casper Hansen

    Turing support (initial)

Parent: eccb8f9c
Showing 2 changed files with 109 additions and 9 deletions:

    awq_cuda/quantization/gemm_cuda_gen.cu    +104 −4
    setup.py                                  +5 −5
awq_cuda/quantization/gemm_cuda_gen.cu (view file @ 7c59407c)
@@ -187,8 +187,41 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
     }
   }
   for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4)
   {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    {
+      __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+        : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+        : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+    }
+    {
+      __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+        : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+        : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+    }
+    {
+      __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+        : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+        : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+    }
+    {
+      __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+        : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+        : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+    }
+#else
     {
-      asm volatile(
+      __asm__ __volatile__(
         "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
         "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
         : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
         : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
     }
@@ -196,12 +229,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
     {
-      asm volatile(
+      __asm__ __volatile__(
         "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
         "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
         : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
         : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
     }
+#endif
   }
  }
 }
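The guard above is the whole Turing story: the f16 `mma.sync.aligned.m16n8k16` shape requires sm_80, so on sm_75 each 16x8x16 tile is issued as two `m16n8k8` instructions that split the A fragment (registers [0],[1] versus [2],[3]) and the B fragment ([0] versus [1]) along k while accumulating into the same C registers. A minimal sketch of that decomposition as a standalone helper follows; the helper itself is hypothetical (the commit inlines the asm directly), and it uses "+f" read-write constraints for brevity where the diff spells out separate "=f" outputs and "f" inputs, but the register pairing matches the hunks above:

    // Hypothetical helper, not in the commit: one 16x8x16 f16 tensor-core
    // tile with f32 accumulation. Per-thread fragments follow the PTX layout
    // the kernels already use: A[0..3] = 8 halfs, B[0..1] = 4 halfs,
    // C[0..3] = 4 floats.
    __device__ __forceinline__ void
    mma_tile_16x8x16(const unsigned *A, const unsigned *B, float *C)
    {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
      // sm_75 lacks the f16 m16n8k16 shape: split along k into two k8 halves.
      // k = 0..7 pairs A[0],A[1] with B[0]; k = 8..15 pairs A[2],A[3] with B[1].
      __asm__ __volatile__(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\n"
          : "+f"(C[0]), "+f"(C[1]), "+f"(C[2]), "+f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(B[0]));
      __asm__ __volatile__(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\n"
          : "+f"(C[0]), "+f"(C[1]), "+f"(C[2]), "+f"(C[3])
          : "r"(A[2]), "r"(A[3]), "r"(B[1]));
    #else
      // sm_80 and newer: a single m16n8k16 covers the whole tile.
      __asm__ __volatile__(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
          : "+f"(C[0]), "+f"(C[1]), "+f"(C[2]), "+f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]));
    #endif
    }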
@@ -384,9 +418,41 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
   for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
   {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    {
+      __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+        : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+        : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+    }
+    {
+      __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+        : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+        : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+    }
+    {
+      __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+        : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+        : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+    }
+    {
+      __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+        : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+        : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+    }
+#else
     {
-      asm volatile(
+      __asm__ __volatile__(
         "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
         "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
         : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
         : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
     }
@@ -394,12 +460,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
     {
-      asm volatile(
+      __asm__ __volatile__(
         "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
         "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
         : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
         : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
     }
+#endif
   }
  }
 }
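The m16n64k32 hunks apply the identical transformation, only with `j_0_4 < 2`. In terms of the hypothetical helper sketched above, each loop iteration is equivalent to two helper calls, one per 16x8 accumulator half; a sketch using the diff's own pointer arithmetic (note the floating-point accumulation order differs slightly from the literal instruction order in the kernel):

    for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
      // n offsets 0 and +4 select the two 16x8 halves of each 16x16 tile.
      mma_tile_16x8x16((const unsigned *)(A_shared_warp + 0),
                       (const unsigned *)(B_shared_warp + (j_0_4 * 8)),
                       C_warp + (j_0_4 * 8));
      mma_tile_16x8x16((const unsigned *)(A_shared_warp + 0),
                       (const unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)),
                       C_warp + ((j_0_4 * 8) + 4));
    }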
@@ -588,7 +655,39 @@ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int s
   for (int i_0_3 = 0; i_0_3 < 4; ++i_0_3)
   {
     for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
     {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+      {
+        asm volatile(
+          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+          : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+          : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
+      }
+      {
+        asm volatile(
+          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+          : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])
+          : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]));
+      }
+      {
+        asm volatile(
+          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+          : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+          : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
+      }
+      {
+        asm volatile(
+          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+          : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])
+          : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]));
+      }
+#else
       {
         asm volatile(
           "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
           "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
           : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
           : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
       }
@@ -604,6 +703,7 @@ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int s
...
@@ -604,6 +703,7 @@ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int s
:
"=f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
3
])
:
"=f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
(
i_0_3
*
8
)))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
(
i_0_3
*
8
)))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
(
i_0_3
*
8
)))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
(
i_0_3
*
8
)))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
3
]));
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
(
i_0_3
*
8
)))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
(
i_0_3
*
8
)))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
(
i_0_3
*
8
)))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
(
i_0_3
*
8
)))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(((
i_0_3
*
16
)
+
(
j_0_4
*
8
))
+
4
)))[
3
]));
}
}
#endif
}
}
}
}
}
}
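Two remarks on the gemmv2 hunks. First, the `asm volatile` spelling here and the `__asm__ __volatile__` spelling in the other two kernels are interchangeable in CUDA C++. Second, `__CUDA_ARCH__` is resolved at compile time, per target architecture: nothing selects a branch at runtime, so the new sm_75 path only exists in binaries actually compiled for compute capability 7.5, which is what the setup.py change below enables. An illustrative stand-alone compile (command is an example, not from the commit) would be:

    nvcc -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -c awq_cuda/quantization/gemm_cuda_gen.cu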
setup.py (view file @ 7c59407c)
@@ -74,15 +74,15 @@ def check_dependencies():

 def get_compute_capabilities():
     # Collect the compute capabilities of all available GPUs.
     compute_capabilities = set()
     for i in range(torch.cuda.device_count()):
         major, minor = torch.cuda.get_device_capability(i)
-        if major < 8:
-            raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")
-        compute_capabilities.add(major * 10 + minor)
+        cc = major * 10 + minor
+        if cc < 75:
+            raise RuntimeError("GPUs with compute capability less than 7.5 are not supported.")

     # figure out compute capability
-    compute_capabilities = {80, 86, 89, 90}
+    compute_capabilities = {75, 80, 86, 89, 90}

     capability_flags = []
     for cap in compute_capabilities:
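Read as a whole, the function after this commit validates the local GPUs against the new 7.5 floor but still builds for a fixed architecture list. A sketch of the resulting logic follows; the flag-building tail is collapsed in the diff above, so the `-gencode` form here is an assumption:

    # Sketch of setup.py's get_compute_capabilities() after this commit.
    # Assumes torch is importable at build time; the -gencode flag format
    # below is illustrative, not taken from the (truncated) diff.
    import torch

    def get_compute_capabilities():
        # Refuse to build on anything older than Turing (compute capability 7.5).
        for i in range(torch.cuda.device_count()):
            major, minor = torch.cuda.get_device_capability(i)
            cc = major * 10 + minor
            if cc < 75:
                raise RuntimeError("GPUs with compute capability less than 7.5 are not supported.")

        # Build fat binaries for every supported architecture, now including 75.
        compute_capabilities = {75, 80, 86, 89, 90}
        capability_flags = []
        for cap in compute_capabilities:
            capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
        return capability_flags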