Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
712e1834
Commit
712e1834
authored
Oct 03, 2023
by
Casper
Browse files
Standardize asm volatile (Windows Support)
parent
7c59407c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
12 deletions
+12
-12
awq_cuda/quantization/gemm_cuda_gen.cu
awq_cuda/quantization/gemm_cuda_gen.cu
+12
-12
No files found.
awq_cuda/quantization/gemm_cuda_gen.cu
View file @
712e1834
...
...
@@ -189,7 +189,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
for
(
int
j_0_4
=
0
;
j_0_4
<
4
;
++
j_0_4
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
...
...
@@ -197,7 +197,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
}
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
...
...
@@ -205,7 +205,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
}
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
...
...
@@ -213,7 +213,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
}
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
...
...
@@ -221,7 +221,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
}
#else
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
...
...
@@ -229,7 +229,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
}
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
...
...
@@ -420,7 +420,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
...
...
@@ -428,7 +428,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
}
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
...
...
@@ -436,7 +436,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
}
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
...
...
@@ -444,7 +444,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
}
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
...
...
@@ -452,7 +452,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
}
#else
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
...
...
@@ -460,7 +460,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
}
{
__
asm
__
__
volatile
__
(
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment