Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
869b7e83
Commit
869b7e83
authored
May 02, 2023
by
Tim Dettmers
Browse files
Warp multi-specialization 240.
parent
77f15fdc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
14 deletions
+56
-14
csrc/kernels.cu
csrc/kernels.cu
+52
-10
tests/test_functional.py
tests/test_functional.py
+4
-4
No files found.
csrc/kernels.cu
View file @
869b7e83
...
@@ -3058,8 +3058,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
...
@@ -3058,8 +3058,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
const
int
half_warp_lane
=
threadIdx
.
x
%
16
;
const
int
half_warp_lane
=
threadIdx
.
x
%
16
;
const
int
batch_size_warps
=
(
WARPS
-
1
)
*
2
;
const
int
batch_size_warps
=
(
WARPS
-
1
)
*
2
;
T
local_A
[
1
];
T
local_A
[
2
];
T
local_B
[
32
];
T
local_B
[
64
];
const
int
a_tile_offset
=
16
;
const
int
a_tile_offset
=
16
;
const
int
b_tile_offset
=
(
16
*
32
+
16
);
const
int
b_tile_offset
=
(
16
*
32
+
16
);
...
@@ -3075,14 +3075,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
...
@@ -3075,14 +3075,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
int
ticktock
=
0
;
int
ticktock
=
0
;
int
idx
=
0
+
threadIdx
.
x
;
int
idx
=
0
+
threadIdx
.
x
;
int
loaded_values
=
0
;
// prefetch
// prefetch
if
(
idx
<
K
&&
warp_id
<
(
WARPS
-
1
))
if
(
idx
<
K
&&
warp_id
<
(
WARPS
-
1
))
{
if
(
loaded_values
==
0
)
{
{
local_A
[
0
]
=
A
[
idx
];
local_A
[
0
]
=
A
[
idx
];
local_A
[
1
]
=
A
[
idx
+
blockDim
.
x
-
32
];
#pragma unroll 32
#pragma unroll 32
for
(
int
col
=
0
;
col
<
32
;
col
++
)
for
(
int
col
=
0
;
col
<
32
;
col
++
)
{
local_B
[
col
]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
local_B
[
col
]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
local_B
[
col
+
32
]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
+
blockDim
.
x
-
32
];
}
loaded_values
=
1
;
}
else
{
local_A
[
0
]
=
local_A
[
1
];
loaded_values
--
;
#pragma unroll 32
for
(
int
col
=
0
;
col
<
32
;
col
++
)
local_B
[
col
]
=
local_B
[
col
+
32
];
}
smem_A
[
half_warp_lane
+
(((
batch_size_warps
*
ticktock
)
+
half_warp_id
)
*
a_tile_offset
)]
=
local_A
[
0
];
smem_A
[
half_warp_lane
+
(((
batch_size_warps
*
ticktock
)
+
half_warp_id
)
*
a_tile_offset
)]
=
local_A
[
0
];
...
@@ -3112,12 +3130,36 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
...
@@ -3112,12 +3130,36 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
__syncthreads
();
__syncthreads
();
if
(
idx
<
K
&&
warp_id
<
(
WARPS
-
1
))
if
(
idx
<
K
&&
warp_id
<
(
WARPS
-
1
))
{
//local_A[0] = A[idx];
//#pragma unroll 32
//for(int col = 0; col < 32; col++)
// local_B[col] = B[(col_offset+col)*ldb+idx];
if
(
loaded_values
==
0
)
{
{
local_A
[
0
]
=
A
[
idx
];
local_A
[
0
]
=
A
[
idx
];
local_A
[
1
]
=
A
[
idx
+
blockDim
.
x
-
32
];
#pragma unroll 32
#pragma unroll 32
for
(
int
col
=
0
;
col
<
32
;
col
++
)
for
(
int
col
=
0
;
col
<
32
;
col
++
)
{
local_B
[
col
]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
local_B
[
col
]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
local_B
[
col
+
32
]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
+
blockDim
.
x
-
32
];
}
loaded_values
=
1
;
}
else
{
local_A
[
0
]
=
local_A
[
1
];
loaded_values
--
;
#pragma unroll 32
for
(
int
col
=
0
;
col
<
32
;
col
++
)
local_B
[
col
]
=
local_B
[
col
+
32
];
}
smem_A
[
half_warp_lane
+
(((
batch_size_warps
*
ticktock
)
+
half_warp_id
)
*
a_tile_offset
)]
=
local_A
[
0
];
smem_A
[
half_warp_lane
+
(((
batch_size_warps
*
ticktock
)
+
half_warp_id
)
*
a_tile_offset
)]
=
local_A
[
0
];
...
...
tests/test_functional.py
View file @
869b7e83
...
@@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype):
...
@@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype):
#print('')
#print('')
#print(A)
#print(A)
#print(B.t())
#print(B.t())
#A[:, :-
3
] = 0
#A[:, :-
1
] = 0
#B[:, :-
3
] = 0
#B[:, :-
1
] = 0
C1
=
torch
.
matmul
(
A
,
B
.
t
())
C1
=
torch
.
matmul
(
A
,
B
.
t
())
...
@@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype):
...
@@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype):
#if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
#if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
# print('')
# print('')
# print(i, err,
mag.item(), relerr.item()
)
# print(i, err,
relerr
)
# print(A.flatten()[-6:])
# print(A.flatten()[-6:])
# print(B.flatten()[-6:])
# print(B.flatten()[-6:])
# out = A.flatten()[-6:]*B.flatten()[-6:]
# out = A.flatten()[-6:]*B.flatten()[-6:]
...
@@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype):
...
@@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype):
c
=
int
(
C1
.
numel
()
*
0.0014
*
(
dim
/
256
))
+
1
c
=
int
(
C1
.
numel
()
*
0.0014
*
(
dim
/
256
))
+
1
c
=
assert_all_approx_close
(
C1
,
C2
,
1e-5
,
0.01
,
count
=
c
,
throw
=
Fals
e
)
c
=
assert_all_approx_close
(
C1
,
C2
,
1e-5
,
0.01
,
count
=
c
,
throw
=
Tru
e
)
#print(c/math.sqrt(dim))
#print(c/math.sqrt(dim))
print
(
''
)
print
(
''
)
print
(
dim
,
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
))
print
(
dim
,
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment