Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
306f6b23
Commit
306f6b23
authored
Jul 10, 2023
by
Tim Dettmers
Browse files
Fixed accidential deletion of limits in kernel.
parent
2221f4ce
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
6 deletions
+17
-6
csrc/kernels.cu
csrc/kernels.cu
+8
-2
tests/test_functional.py
tests/test_functional.py
+9
-4
No files found.
csrc/kernels.cu
View file @
306f6b23
...
@@ -3595,7 +3595,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
...
@@ -3595,7 +3595,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
#endif
#endif
}
}
if
(
inner_idx
+
num_values_4bit
)
if
(
inner_idx
+
num_values_4bit
<
K
)
{
{
if
(
BITS
==
16
)
if
(
BITS
==
16
)
{
{
...
@@ -3619,11 +3619,17 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
...
@@ -3619,11 +3619,17 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
}
}
else
else
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
local_A
[
k
]
=
A
[
inner_idx
+
k
];
if
(
inner_idx
+
k
<
K
)
local_A
[
k
]
=
A
[
inner_idx
+
k
];
else
local_A
[
k
]
=
T
(
0.0
f
);
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
{
{
if
((
float
)
local_A
[
k
]
<
-
10.0
f
||
(
float
)
local_B
[
k
]
<
-
10.0
f
||
local_C
>
10.0
f
)
printf
(
"%f %f = %f
\n
"
,
(
float
)
local_A
[
k
],
(
float
)
local_B
[
k
],
local_C
);
#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800
local_C
+=
(
float
)(
local_A
[
k
]
*
local_B
[
k
]);
local_C
+=
(
float
)(
local_A
[
k
]
*
local_B
[
k
]);
#else
#else
...
...
tests/test_functional.py
View file @
306f6b23
...
@@ -2378,8 +2378,8 @@ def test_gemv_4bit(dtype, storage_type, double_quant):
...
@@ -2378,8 +2378,8 @@ def test_gemv_4bit(dtype, storage_type, double_quant):
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
A
=
torch
.
randn
(
1
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
A
=
torch
.
randn
(
1
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
B
=
torch
.
randn
(
4
*
dim
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
dim
)
#
B = torch.randn(4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
#
B = torch.randn(
1
, dim
+2
, dtype=dtype, device='cuda')/math.sqrt(dim)
B
=
torch
.
randn
(
dim
*
4
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
dim
)
#print('')
#print('')
#print(A)
#print(A)
...
@@ -2432,13 +2432,18 @@ def test_gemv_4bit(dtype, storage_type, double_quant):
...
@@ -2432,13 +2432,18 @@ def test_gemv_4bit(dtype, storage_type, double_quant):
#print(dim, (max_err.item(), max_relerr.item()))
#print(dim, (max_err.item(), max_relerr.item()))
print
(
C1
.
flatten
()[
-
20
:])
print
(
C1
.
flatten
()[
-
20
:])
print
(
C2
.
flatten
()[
-
20
:])
print
(
C2
.
flatten
()[
-
20
:])
print
(
C3
.
flatten
()[
-
20
:])
#print(C1.flatten())
#print(C2.flatten())
#print(C3.flatten()[-20:])
print
(
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
,
dim
)
print
(
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
,
dim
)
print
(
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
,
dim
)
print
(
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
,
dim
)
if
dtype
==
torch
.
float16
:
if
dtype
==
torch
.
float16
:
assert
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
<
5e-5
assert
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
<
5e-5
assert
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
<
0.0005
assert
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
<
0.0005
else
:
elif
dtype
==
torch
.
float32
:
assert
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
<
5e-8
assert
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
<
1e-8
elif
dtype
==
torch
.
bfloat16
:
assert
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
<
3e-4
assert
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
<
3e-4
assert
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
<
0.003
assert
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
<
0.003
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment