OpenDAS / bitsandbytes · Commit 9192c9de
Authored May 02, 2023 by Tim Dettmers

Tighter and scaled error analysis.

Parent: f9bfea8f

2 changed files with 70 additions and 42 deletions:

  csrc/kernels.cu           +14  -1
  tests/test_functional.py  +56  -41

csrc/kernels.cu (view file @ 9192c9de)

@@ -3123,6 +3123,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     }
     ticktock = ticktock == 0 ? 1 : 0;
 
+  //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32)
   {
     idx = base_idx + threadIdx.x;
@@ -3155,8 +3156,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       for(int col = 0; col < 32; col++)
         smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
     }
-    ticktock = ticktock == 0 ? 1 : 0;
+    //ticktock = ticktock == 0 ? 1 : 0;
+    __syncthreads();
 
     if(warp_id == (WARPS-1))
       for(int k = 0; k < batch_size_warps; k++)
       {
@@ -3166,11 +3168,22 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       }
   }
 
+  //__syncthreads();
+  //if(warp_id == (WARPS-1))
+  //  for(int k = 0; k < batch_size_warps; k++)
+  //  {
+  //    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+  //    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+  //    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+  //  }
+  __syncthreads();
+
   // 129 mu
   if(warp_id == (WARPS-1))
     wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
   __syncthreads();
 
   //if(threadIdx.x >= 16){ return; }
   //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);

tests/test_functional.py (view file @ 9192c9de)

@@ -2355,47 +2355,62 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for i in range(100):
-        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-        #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-        A = torch.randn(1, 128+32, dtype=dtype, device='cuda')
-        B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128)
-
-        #print('')
-        #print(A)
-        #print(B.t())
-        #A[:, :-3] = 0
-        #B[:, :-3] = 0
-
-        C1 = torch.matmul(A, B.t())
-        C2 = F.cutlass3_gemm(A, B.t())
-        err = C1-C2
-        # tensor cores are non-deterministic
-        # so we need to analyze errors around the mean
-        # to test our implementation
-        err = torch.abs(err.mean()).item()
-        mag = torch.abs(C1).mean()
-        relerr = err/mag
-
-        if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
-            print('')
-            print(i, err, mag.item(), relerr.item())
-            print(A.flatten()[-6:])
-            print(B.flatten()[-6:])
-            out = A.flatten()[-6:]*B.flatten()[-6:]
-            print(out)
-            print(out[:-1].sum())
-            print('='*80)
-            print(C1.flatten()[-6:])
-            print(C2.flatten()[-6:])
-            #assert False, 'ERROR'
-
-        c = int(C1.numel()*0.001)
-        assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
-
-
+    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+        errs = []
+        relerrs = []
+        max_err = 0
+        max_relerr = 0
+        for i in range(100):
+            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            B = torch.randn(4*496, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
+            #print('')
+            #print(A)
+            #print(B.t())
+            #A[:, :-3] = 0
+            #B[:, :-3] = 0
+
+            C1 = torch.matmul(A, B.t())
+            C2 = F.cutlass3_gemm(A, B.t())
+
+            # tensor cores are non-deterministic
+            # so we need to analyze errors around the mean
+            # to test our implementation
+            err = torch.abs(C1-C2)
+            mag = torch.abs(C1)+1e-8
+            relerr = err/mag
+            max_err = max(err.max(), max_err)
+            max_relerr = max(relerr.max(), max_relerr)
+            err = err.mean().item()
+            relerr = relerr.mean().item()
+
+            errs.append(err)
+            relerrs.append(relerr)
+
+            #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            #    print('')
+            #    print(i, err, mag.item(), relerr.item())
+            #    print(A.flatten()[-6:])
+            #    print(B.flatten()[-6:])
+            #    out = A.flatten()[-6:]*B.flatten()[-6:]
+            #    print(out)
+            #    print(out[:-1].sum())
+            #    print('='*80)
+            #    print(C1.flatten()[-6:])
+            #    print(C2.flatten()[-6:])
+            #    #assert False, 'ERROR'
+
+            c = int(C1.numel()*0.00125*(dim/256))+1
+
+            assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
+        print('')
+        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
+        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
+        print(dim, (max_err.item(), max_relerr.item()))
 
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
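
To make the "tighter and scaled" part of this change easier to follow, below is a minimal, self-contained sketch of the same error analysis written against plain PyTorch. It is an approximation, not the repository's code: a CUDA device is assumed, an fp16 torch.matmul stands in for F.cutlass3_gemm, and check_approx_close is a hypothetical local replacement for the test suite's assert_all_approx_close helper. What it preserves is the bookkeeping the new test introduces: B is scaled by 1/sqrt(dim) so the output magnitude stays roughly constant across sizes, absolute and relative errors are tracked per element (mean and max), and the number of elements allowed to miss the tolerance grows linearly with dim via c = int(C1.numel()*0.00125*(dim/256)) + 1.

# Sketch of the scaled error analysis in test_cutlass3_gemm (assumptions:
# CUDA device available; fp16 torch.matmul stands in for F.cutlass3_gemm;
# check_approx_close is a hypothetical stand-in for assert_all_approx_close).
import math
import torch

def check_approx_close(a, b, rtol, atol, count):
    # Allow up to `count` elements to violate the tolerance; tensor-core
    # accumulation order makes a few outliers expected.
    mismatches = (~torch.isclose(a, b, rtol=rtol, atol=atol)).sum().item()
    assert mismatches <= count, f"{mismatches} mismatches (budget {count})"

def analyze(dim, iters=100):
    errs, relerrs = [], []
    max_err, max_relerr = 0.0, 0.0
    for _ in range(iters):
        A = torch.randn(1, dim, dtype=torch.float16, device='cuda')
        # Scaling B by 1/sqrt(dim) keeps |C1| roughly O(1) for every dim,
        # so absolute-error budgets stay comparable across sizes.
        B = torch.randn(4*496, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)

        C1 = torch.matmul(A.float(), B.t().float()).half()  # fp32 reference, cast back
        C2 = torch.matmul(A, B.t())                         # stand-in for F.cutlass3_gemm(A, B.t())

        err = torch.abs(C1 - C2)
        mag = torch.abs(C1) + 1e-8
        relerr = err / mag
        max_err = max(err.max().item(), max_err)
        max_relerr = max(relerr.max().item(), max_relerr)
        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())

        # Mismatch budget scales linearly with dim: 0.125% of the output
        # elements at dim=256, times dim/256, plus one.
        c = int(C1.numel() * 0.00125 * (dim / 256)) + 1
        check_approx_close(C1, C2, rtol=1e-5, atol=0.01, count=c)

    print(dim, sum(errs) / len(errs) / math.sqrt(dim))
    print(dim, sum(relerrs) / len(relerrs) / math.sqrt(dim))
    print(dim, (max_err, max_relerr))

if __name__ == '__main__':
    for dim in [32, 64, 128, 256]:  # smaller sizes than the test, for a quick run
        analyze(dim)

Dividing the mean errors by math.sqrt(dim) in the final prints mirrors the test; presumably it normalizes out the roughly sqrt(dim) growth of accumulated rounding error with the inner dimension, so the per-dim summaries can be compared directly.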