Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
9192c9de
Commit
9192c9de
authored
May 02, 2023
by
Tim Dettmers
Browse files
Tighter and scaled error analysis.
parent
f9bfea8f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
42 deletions
+70
-42
csrc/kernels.cu
csrc/kernels.cu
+14
-1
tests/test_functional.py
tests/test_functional.py
+56
-41
No files found.
csrc/kernels.cu
View file @
9192c9de
...
@@ -3123,6 +3123,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
...
@@ -3123,6 +3123,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
}
}
ticktock
=
ticktock
==
0
?
1
:
0
;
ticktock
=
ticktock
==
0
?
1
:
0
;
//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for
(
int
base_idx
=
0
;
base_idx
<
K
;
base_idx
+=
blockDim
.
x
-
32
)
for
(
int
base_idx
=
0
;
base_idx
<
K
;
base_idx
+=
blockDim
.
x
-
32
)
{
{
idx
=
base_idx
+
threadIdx
.
x
;
idx
=
base_idx
+
threadIdx
.
x
;
...
@@ -3155,8 +3156,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
...
@@ -3155,8 +3156,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
for
(
int
col
=
0
;
col
<
32
;
col
++
)
for
(
int
col
=
0
;
col
<
32
;
col
++
)
smem_B
[
half_warp_lane
+
(((
batch_size_warps
*
ticktock
)
+
half_warp_id
)
*
b_tile_offset
)
+
(
col
*
16
)]
=
0.0
f
;
smem_B
[
half_warp_lane
+
(((
batch_size_warps
*
ticktock
)
+
half_warp_id
)
*
b_tile_offset
)
+
(
col
*
16
)]
=
0.0
f
;
}
}
ticktock
=
ticktock
==
0
?
1
:
0
;
//
ticktock = ticktock == 0 ? 1 : 0;
__syncthreads
();
if
(
warp_id
==
(
WARPS
-
1
))
if
(
warp_id
==
(
WARPS
-
1
))
for
(
int
k
=
0
;
k
<
batch_size_warps
;
k
++
)
for
(
int
k
=
0
;
k
<
batch_size_warps
;
k
++
)
{
{
...
@@ -3166,11 +3168,22 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
...
@@ -3166,11 +3168,22 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
}
}
}
}
//__syncthreads();
//if(warp_id == (WARPS-1))
// for(int k = 0; k < batch_size_warps; k++)
// {
// wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
// wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
// wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
// }
__syncthreads
();
// 129 mu
// 129 mu
if
(
warp_id
==
(
WARPS
-
1
))
if
(
warp_id
==
(
WARPS
-
1
))
wmma
::
store_matrix_sync
(
&
(
smem_C
[
0
]),
c_frag
,
32
,
wmma
::
mem_row_major
);
wmma
::
store_matrix_sync
(
&
(
smem_C
[
0
]),
c_frag
,
32
,
wmma
::
mem_row_major
);
__syncthreads
();
__syncthreads
();
//if(threadIdx.x >= 16){ return; }
//if(threadIdx.x >= 16){ return; }
//printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
//printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
...
...
tests/test_functional.py
View file @
9192c9de
...
@@ -2355,47 +2355,62 @@ def test_normal_map_tree():
...
@@ -2355,47 +2355,62 @@ def test_normal_map_tree():
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
],
ids
=
[
'fp16'
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
],
ids
=
[
'fp16'
])
def
test_cutlass3_gemm
(
dtype
):
def
test_cutlass3_gemm
(
dtype
):
for
i
in
range
(
100
):
for
dim
in
[
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]:
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
errs
=
[]
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
relerrs
=
[]
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
max_err
=
0
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
max_relerr
=
0
A
=
torch
.
randn
(
1
,
128
+
32
,
dtype
=
dtype
,
device
=
'cuda'
)
for
i
in
range
(
100
):
B
=
torch
.
randn
(
4096
,
128
+
32
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
128
)
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#print('')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#print(A)
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
#print(B.t())
A
=
torch
.
randn
(
1
,
dim
+
0
,
dtype
=
dtype
,
device
=
'cuda'
)
#A[:, :-3] = 0
B
=
torch
.
randn
(
4
*
496
,
dim
+
0
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
dim
)
#B[:, :-3] = 0
#print('')
#print(A)
C1
=
torch
.
matmul
(
A
,
B
.
t
())
#print(B.t())
C2
=
F
.
cutlass3_gemm
(
A
,
B
.
t
())
#A[:, :-3] = 0
err
=
C1
-
C2
#B[:, :-3] = 0
# tensor cores are non-deterministic
# so we need to analyze errors around the mean
C1
=
torch
.
matmul
(
A
,
B
.
t
())
# to test our implementation
C2
=
F
.
cutlass3_gemm
(
A
,
B
.
t
())
err
=
torch
.
abs
(
err
.
mean
()).
item
()
mag
=
torch
.
abs
(
C1
).
mean
()
# tensor cores are non-deterministic
relerr
=
err
/
mag
# so we need to analyze errors around the mean
# to test our implementation
if
err
/
torch
.
abs
(
C1
).
mean
()
>
5e-5
or
err
>
3.2e-5
:
err
=
torch
.
abs
(
C1
-
C2
)
print
(
''
)
mag
=
torch
.
abs
(
C1
)
+
1e-8
print
(
i
,
err
,
mag
.
item
(),
relerr
.
item
())
relerr
=
err
/
mag
print
(
A
.
flatten
()[
-
6
:])
max_err
=
max
(
err
.
max
(),
max_err
)
print
(
B
.
flatten
()[
-
6
:])
max_relerr
=
max
(
relerr
.
max
(),
max_relerr
)
out
=
A
.
flatten
()[
-
6
:]
*
B
.
flatten
()[
-
6
:]
err
=
err
.
mean
().
item
()
print
(
out
)
relerr
=
relerr
.
mean
().
item
()
print
(
out
[:
-
1
].
sum
())
print
(
'='
*
80
)
errs
.
append
(
err
)
print
(
C1
.
flatten
()[
-
6
:])
relerrs
.
append
(
relerr
)
print
(
C2
.
flatten
()[
-
6
:])
#assert False, 'ERROR'
#if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
# print('')
c
=
int
(
C1
.
numel
()
*
0.001
)
# print(i, err, mag.item(), relerr.item())
assert_all_approx_close
(
C1
,
C2
,
1e-5
,
0.01
,
count
=
c
)
# print(A.flatten()[-6:])
# print(B.flatten()[-6:])
# out = A.flatten()[-6:]*B.flatten()[-6:]
# print(out)
# print(out[:-1].sum())
# print('='*80)
# print(C1.flatten()[-6:])
# print(C2.flatten()[-6:])
# #assert False, 'ERROR'
c
=
int
(
C1
.
numel
()
*
0.00125
*
(
dim
/
256
))
+
1
assert_all_approx_close
(
C1
,
C2
,
1e-5
,
0.01
,
count
=
c
)
print
(
''
)
print
(
dim
,
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
))
print
(
dim
,
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
))
print
(
dim
,
(
max_err
.
item
(),
max_relerr
.
item
()))
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
],
ids
=
[
'fp16'
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
],
ids
=
[
'fp16'
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment