OpenDAS / bitsandbytes

Commit f9bfea8f, authored May 02, 2023 by Tim Dettmers
Parent: 7bfa09d0

    Baseline for debugging.

Showing 4 changed files with 66 additions and 19 deletions.
bitsandbytes/functional.py    +1   -1
csrc/kernels.cu               +28  -3
csrc/ops.cu                   +8   -8
tests/test_functional.py      +29  -7
bitsandbytes/functional.py

@@ -1467,7 +1467,7 @@ def cutlass3_gemm(
         lda = Bshape[1]
         ldc = Bshape[0]
         ldb = (ldb+1)//2
-    print(m, n, k, lda, ldb, ldc)
+    #print(m, n, k, lda, ldb, ldc)
     is_on_gpu([B, A, out])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
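The only change here disables a debugging print in the Python wrapper for cutlass3_gemm. For orientation, the surrounding context follows a common ctypes dispatch pattern: verify every operand is on the GPU, then cast the GEMM dimensions and leading dimensions to C int32 before calling into the native library. The sketch below illustrates that pattern only; the function name lib.gemm_fp16, the _get_ptr helper, and the leading-dimension choices are placeholders, not the actual bitsandbytes binding.

# Minimal sketch of the dispatch pattern visible in this hunk. Placeholder
# names (lib.gemm_fp16, _get_ptr) stand in for the real bitsandbytes symbols;
# the real wrapper derives lda/ldb/ldc from Bshape as shown in the diff above.
import ctypes as ct
import torch

def _get_ptr(t: torch.Tensor) -> ct.c_void_p:
    # raw device pointer handed to the ctypes binding
    return ct.c_void_p(t.data_ptr())

def gemm_dispatch_sketch(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor, lib) -> None:
    m, k = A.shape              # A is (m, k)
    n = B.shape[0]              # B stored as (n, k); illustrative assumption
    lda, ldb, ldc = k, k, n     # assumed layout, for illustration only

    # mirrors is_on_gpu([B, A, out]): every operand must live on the GPU
    assert all(t.is_cuda for t in (A, B, out)), "all tensors must be CUDA tensors"

    # hypothetical native entry point; dimensions are passed as C int32
    lib.gemm_fp16(ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
                  _get_ptr(A), _get_ptr(B), _get_ptr(out),
                  ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))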
csrc/kernels.cu

@@ -3061,9 +3061,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
   T local_A[1];
   T local_B[32];
 
-  const int a_tile_offset = (8*16 + 16);
-  const int b_tile_offset = (16*32 + 16);
-  const int c_tile_offset = 8*32 + 24;
+  const int a_tile_offset = (8*16);
+  const int b_tile_offset = (16*32);
 
   __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))];
   __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];

@@ -3109,6 +3108,19 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
         for(int col = 0; col < 32; col++)
           smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col];
       }
+      else if(warp_id < (WARPS-1))
+      {
+        local_A[0] = T(0.0);
+        smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = T(0.0);
+
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = T(0.0f);
+
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = T(0.0f);
+      }
       ticktock = ticktock == 0 ? 1 : 0;
 
   for(int base_idx = 0; base_idx < K; base_idx += blockDim.x-32)

@@ -3130,6 +3142,19 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
         for(int col = 0; col < 32; col++)
           smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
       }
+      else if(warp_id < (WARPS-1))
+      {
+        local_A[0] = T(0.0);
+        smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;
+
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = 0.0f;
+
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
+      }
       ticktock = ticktock == 0 ? 1 : 0;
 
     if(warp_id == (WARPS-1))
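Two things change in gemm_device: the per-warp tile strides lose their extra padding (a_tile_offset and b_tile_offset become exactly 8*16 and 16*32, and c_tile_offset is dropped), and new else if (warp_id < (WARPS-1)) branches zero-fill the local and shared-memory tiles for warps that have nothing to load, so stale values never enter the accumulation. As a rough, illustrative check of the shared-memory footprint implied by the smem_A/smem_B declarations, here is a small Python calculation; batch_size_warps = 4 and fp16 elements are assumed example values, not taken from the kernel configuration.

# Shared-memory footprint implied by the declarations in gemm_device:
#   smem_A[2*batch_size_warps*8*16  + 2*16*(batch_size_warps-1)]
#   smem_B[2*batch_size_warps*16*32 + 2*16*(batch_size_warps-1)]
# batch_size_warps = 4 is an assumed example value, not read from the kernel.
batch_size_warps = 4
elem_bytes = 2  # fp16 (half) elements

smem_A_elems = 2 * batch_size_warps * 8 * 16 + 2 * 16 * (batch_size_warps - 1)
smem_B_elems = 2 * batch_size_warps * 16 * 32 + 2 * 16 * (batch_size_warps - 1)

a_tile_offset = 8 * 16    # stride between half-warp A tiles (was 8*16 + 16)
b_tile_offset = 16 * 32   # stride between half-warp B tiles (was 16*32 + 16)

total_bytes = (smem_A_elems + smem_B_elems) * elem_bytes
print(f"smem_A: {smem_A_elems} elems, smem_B: {smem_B_elems} elems, "
      f"total = {total_bytes} bytes per block")
# -> smem_A: 1120 elems, smem_B: 4192 elems, total = 10624 bytes per block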
csrc/ops.cu

@@ -680,14 +680,14 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
   int num_blocks = (m+31)/32;
 
-  cout << num_blocks << endl;
-  cout << lda << endl;
-  cout << ldb << endl;
-  cout << ldc << endl;
-  cout << m << endl;
-  cout << n << endl;
-  cout << k << endl;
+  //cout << num_blocks << endl;
+  //cout << lda << endl;
+  //cout << ldb << endl;
+  //cout << ldc << endl;
+  //cout << m << endl;
+  //cout << n << endl;
+  //cout << k << endl;
   //if(bits == 32)
     //gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
     //gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
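Apart from commenting out the cout debugging output, gemm_host keeps its launch configuration of one block per 32 rows of output, i.e. a ceiling division of m by 32. A quick Python illustration of that arithmetic:

# One block per 32 rows: the C++ host code writes this as (m + 31) / 32.
def num_blocks_for(m: int, rows_per_block: int = 32) -> int:
    return (m + rows_per_block - 1) // rows_per_block

assert num_blocks_for(1) == 1      # a single-row GEMM still launches one block
assert num_blocks_for(32) == 1
assert num_blocks_for(33) == 2
assert num_blocks_for(4096) == 128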
tests/test_functional.py

@@ -2355,25 +2355,47 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for i in range(1):
+    for i in range(100):
         #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
         #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
         #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
         #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+        A = torch.randn(1, 128+32, dtype=dtype, device='cuda')
+        B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128)
 
         #print('')
         #print(A)
         #print(B.t())
+        #A[:, :-3] = 0
+        #B[:, :-3] = 0
 
         C1 = torch.matmul(A, B.t())
         C2 = F.cutlass3_gemm(A, B.t())
-        print(C1)
-        print(C2)
-        torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.06)
+        err = C1-C2
+
+        # tensor cores are non-deterministic
+        # so we need to analyze errors around the mean
+        # to test our implementation
+        err = torch.abs(err.mean()).item()
+        mag = torch.abs(C1).mean()
+        relerr = err/mag
+
+        if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            print('')
+            print(i, err, mag.item(), relerr.item())
+            print(A.flatten()[-6:])
+            print(B.flatten()[-6:])
+            out = A.flatten()[-6:]*B.flatten()[-6:]
+            print(out)
+            print(out[:-1].sum())
+            print('='*80)
+            print(C1.flatten()[-6:])
+            print(C2.flatten()[-6:])
+            #assert False, 'ERROR'
+
+        c = int(C1.numel()*0.001)
+        assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
 
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
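The rewritten test no longer demands elementwise closeness via torch.testing.assert_close: because tensor-core accumulation order is non-deterministic, it checks the error around the mean and then calls the suite's assert_all_approx_close helper with a small mismatch budget (count = 0.1% of the elements). That helper is defined elsewhere in the test file; the sketch below only illustrates the idea of a tolerance check with an allowed number of outliers and is not the repository's implementation.

import torch

def approx_close_with_outliers(a: torch.Tensor, b: torch.Tensor,
                               rtol: float, atol: float, count: int) -> None:
    # Illustrative only: count the elements outside the tolerance band and fail
    # only if more than `count` of them miss it.
    mismatched = ~torch.isclose(a, b, rtol=rtol, atol=atol)
    num_bad = int(mismatched.sum().item())
    assert num_bad <= count, f"{num_bad} elements exceed tolerance (budget: {count})"

# Usage mirroring the test: tolerate up to 0.1% mismatching elements.
C1 = torch.randn(1, 4096)
C2 = C1 + 1e-6 * torch.randn_like(C1)
approx_close_with_outliers(C1, C2, rtol=1e-5, atol=0.01, count=int(C1.numel() * 0.001))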