OpenDAS / AutoAWQ_kernels · Commits

Commit bad253e6, authored Feb 14, 2024 by Casper
Parent: 8907d182

    Windows support

1 changed file with 10 additions and 10 deletions:

awq_ext/quantization/gemm_cuda_gen.cu (+10, -10)
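Note on the change (context, not part of the commit): the whole diff is one mechanical substitution, replacing the GNU-style __asm__ __volatile__ keywords with the asm volatile spelling that CUDA's inline-PTX documentation uses. The GNU spellings are a GCC/Clang extension, so they fail when nvcc drives MSVC as the host compiler, which is presumably what "Windows support" refers to. A minimal sketch of the portable form (hypothetical kernel, not taken from this file):

#include <cstdio>
#include <cuda_runtime.h>

// Reads the warp lane id via inline PTX. "=r" binds a 32-bit register
// output; "%%" escapes the literal '%' in the special-register name.
__global__ void lane_id_kernel(unsigned int *out)
{
    unsigned int lane;
    asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane));
    out[threadIdx.x] = lane;
}

int main()
{
    unsigned int *d_out, h_out[32];
    cudaMalloc(&d_out, 32 * sizeof(unsigned int));
    lane_id_kernel<<<1, 32>>>(d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("lane 5 -> %u\n", h_out[5]);  // expected: 5
    cudaFree(d_out);
    return 0;
}

Both spellings should emit identical PTX under a GCC host; only the portable one also compiles under MSVC.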
@@ -932,14 +932,14 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
     for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
     {
       {
         unsigned int addr;
-        __asm__ __volatile__(
+        asm volatile(
           "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
           : "=r"(addr)
           : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
         );
-        __asm__ __volatile__(
+        asm volatile(
           "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
           "{%0, %1, %2, %3}, [%4];\n"
           : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
@@ -950,12 +950,12 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
     for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0)
     {
       {
         unsigned int addr;
-        __asm__ __volatile__(
+        asm volatile(
           "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
           : "=r"(addr)
           : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
         );
-        __asm__ __volatile__(
+        asm volatile(
           "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
           "{%0, %1, %2, %3}, [%4];\n"
           : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
@@ -966,7 +966,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
     for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4)
     {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
       {
-        __asm__ __volatile__(
+        asm volatile(
           "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
           "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
           : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
@@ -974,7 +974,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
       }
       {
-        __asm__ __volatile__(
+        asm volatile(
           "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
           "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
           : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
@@ -982,7 +982,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
       }
       {
-        __asm__ __volatile__(
+        asm volatile(
           "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
           "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
           : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
@@ -990,7 +990,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
       }
       {
-        __asm__ __volatile__(
+        asm volatile(
           "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
           "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
           : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
@@ -998,7 +998,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
       }
 #else
       {
-        __asm__ __volatile__(
+        asm volatile(
           "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
           "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
           : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
@@ -1006,7 +1006,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
       }
       {
-        __asm__ __volatile__(
+        asm volatile(
           "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
           "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
           : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
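For reference, the mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 instruction touched in the last two hunks is the warp-wide tensor-core multiply: a 16x16 f16 A fragment times a 16x8 f16 B fragment, accumulated into 16x8 f32. On sm_75 the same update is issued as two m16n8k8 instructions instead, which is what the __CUDA_ARCH__ == 750 guard in the earlier hunks selects. A standalone sketch of the portable asm volatile pattern for it (hypothetical helper, assuming an sm_80+ target; mma_m16n8k16 is not a name from this repo):

// One m16n8k16 f16 MMA with f32 accumulation, written with the portable
// asm volatile spelling this commit switches to. Per lane: four 32-bit
// words of A (8 halves), two of B (4 halves), four f32 accumulators.
__device__ void mma_m16n8k16(float c[4], const unsigned a[4], const unsigned b[2])
{
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
        : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
}

The "+f" constraints make the accumulators both inputs and outputs, matching the D = A*B + C form the kernel relies on; in the generated kernel the two calls per j_0_4 (C_warp + 0 and C_warp + 4) cover the two 8-column halves of the 16-wide output tile.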