Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
c82f51c0
Commit
c82f51c0
authored
Jul 19, 2023
by
Tim Dettmers
Browse files
Increased occupancy.
parent
e229fbce
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
58 deletions
+52
-58
bitsandbytes/cuda_setup/main.py
bitsandbytes/cuda_setup/main.py
+1
-1
csrc/kernels.cu
csrc/kernels.cu
+44
-50
tests/test_functional.py
tests/test_functional.py
+7
-7
No files found.
bitsandbytes/cuda_setup/main.py
View file @
c82f51c0
csrc/kernels.cu
View file @
c82f51c0
...
@@ -3540,8 +3540,8 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
...
@@ -3540,8 +3540,8 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
float
local_C
=
0.0
f
;
float
local_C
=
0.0
f
;
unsigned
char
local_B_4bit
[
num_values_8bit
];
unsigned
char
local_B_4bit
[
num_values_8bit
];
T
local_B
[
num_values_4bit
];
T
local_B
[
num_values_4bit
/
4
];
T
local_A
[
num_values_4bit
];
T
local_A
[
num_values_4bit
/
4
];
__shared__
T
quant_map
[
16
];
__shared__
T
quant_map
[
16
];
T
local_absmax
=
T
(
0.0
f
);
T
local_absmax
=
T
(
0.0
f
);
...
@@ -3582,54 +3582,47 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
...
@@ -3582,54 +3582,47 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
local_B_4bit
[
j
]
=
0b01110111
;
local_B_4bit
[
j
]
=
0b01110111
;
}
}
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
num_values_8bit
;
k
++
)
for
(
int
k
=
0
;
k
<
num_values_8bit
/
4
;
k
++
)
{
{
#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800
local_B
[
k
*
2
]
=
quant_map
[
local_B_4bit
[
k
]
>>
4
]
*
local_absmax
;
local_B
[
k
*
2
]
=
quant_map
[
local_B_4bit
[
(
i
*
num_values_8bit
/
4
)
+
k
]
>>
4
]
*
local_absmax
;
local_B
[
k
*
2
+
1
]
=
quant_map
[
local_B_4bit
[
k
]
&
0x0F
]
*
local_absmax
;
local_B
[
k
*
2
+
1
]
=
quant_map
[
local_B_4bit
[
(
i
*
num_values_8bit
/
4
)
+
k
]
&
0x0F
]
*
local_absmax
;
#else
#else
// bf16 multipliation not supported
// bf16 multipliation not supported
local_B
[
k
*
2
]
=
T
((
float
)
quant_map
[
local_B_4bit
[
k
]
>>
4
]
*
(
float
)
local_absmax
);
local_B
[
k
*
2
]
=
T
((
float
)
quant_map
[
local_B_4bit
[
(
i
*
num_values_8bit
/
4
)
+
k
]
>>
4
]
*
(
float
)
local_absmax
);
local_B
[
k
*
2
+
1
]
=
T
((
float
)
quant_map
[
local_B_4bit
[
k
]
&
0x0F
]
*
(
float
)
local_absmax
);
local_B
[
k
*
2
+
1
]
=
T
((
float
)
quant_map
[
local_B_4bit
[
(
i
*
num_values_8bit
/
4
)
+
k
]
&
0x0F
]
*
(
float
)
local_absmax
);
#endif
#endif
}
}
if
(
inner_idx
+
num_values_4bit
<
K
)
if
(
inner_idx
+
(
num_values_4bit
/
4
)
+
(
i
*
num_values_4bit
/
4
)
<
K
)
{
{
// this is also relatively important for performance
// this is also relatively important for performance
if
(
BITS
==
16
)
if
(
BITS
==
16
)
{
{
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
0
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
4
)
+
0
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
0
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
4
)
+
i
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
1
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
4
)
+
1
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
2
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
4
)
+
2
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
3
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
4
)
+
3
];
}
}
else
else
{
{
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
0
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
0
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
0
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
(
2
*
i
)
+
0
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
1
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
1
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
1
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
(
2
*
i
)
+
1
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
2
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
2
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
3
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
3
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
4
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
4
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
5
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
5
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
6
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
6
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
7
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_4bit
/
8
)
+
7
];
}
}
}
}
else
else
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
for
(
int
k
=
0
;
k
<
num_values_4bit
/
4
;
k
++
)
if
(
inner_idx
+
k
<
K
)
if
(
inner_idx
+
(
i
*
num_values_4bit
/
4
)
+
k
<
K
)
local_A
[
k
]
=
A
[
inner_idx
+
k
];
local_A
[
k
]
=
A
[
inner_idx
+
k
+
(
i
*
num_values_4bit
/
4
)
];
else
else
local_A
[
k
]
=
T
(
0.0
f
);
local_A
[
k
]
=
T
(
0.0
f
);
// accumulate in float; small performance hit for Ampere, but lower error for outputs
// accumulate in float; small performance hit for Ampere, but lower error for outputs
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
for
(
int
k
=
0
;
k
<
num_values_4bit
/
4
;
k
++
)
{
{
#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800
local_C
+=
(
float
)(
local_A
[
k
]
*
local_B
[
k
]);
local_C
+=
(
float
)(
local_A
[
k
]
*
local_B
[
k
]);
...
@@ -3639,6 +3632,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
...
@@ -3639,6 +3632,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
#endif
#endif
}
}
}
}
}
local_C
=
WarpReduce
(
temp_storage
[
warp_idx
]).
Sum
(
local_C
);
local_C
=
WarpReduce
(
temp_storage
[
warp_idx
]).
Sum
(
local_C
);
...
...
tests/test_functional.py
View file @
c82f51c0
...
@@ -2366,7 +2366,7 @@ def test_normal_map_tree():
...
@@ -2366,7 +2366,7 @@ def test_normal_map_tree():
def
test_gemv_4bit
(
dtype
,
storage_type
,
double_quant
,
kind
):
def
test_gemv_4bit
(
dtype
,
storage_type
,
double_quant
,
kind
):
for
dim
in
[
128
,
256
,
512
,
1024
]:
for
dim
in
[
128
,
256
,
512
,
1024
]:
#for dim in [4*1024]:
#for dim in [4*1024]:
#for dim in [1*1
28
]:
#for dim in [1*1
6
]:
errs1
=
[]
errs1
=
[]
errs2
=
[]
errs2
=
[]
errs3
=
[]
errs3
=
[]
...
@@ -2446,11 +2446,11 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
...
@@ -2446,11 +2446,11 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
#
#
#print('='*80)
#print('='*80)
#print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
#print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
#
print(C1.flatten()[-20:])
print
(
C1
.
flatten
()[
-
20
:])
#
print(C2.flatten()[-20:])
print
(
C2
.
flatten
()[
-
20
:])
#
print(f'inference vs training abs: {err1}')
print
(
f
'inference vs training abs:
{
err1
}
'
)
#
print(f'inference vs training rel: {relerr1}')
print
(
f
'inference vs training rel:
{
relerr1
}
'
)
#
print(f'inference vs training max: {maxerr1}')
print
(
f
'inference vs training max:
{
maxerr1
}
'
)
#print(f'inference vs training vs torch err ratio abs: {absratio}')
#print(f'inference vs training vs torch err ratio abs: {absratio}')
#print(f'inference vs training vs torch err ratio rel: {relratio}')
#print(f'inference vs training vs torch err ratio rel: {relratio}')
#print(f'inference vs training vs torch err ratio max: {maxratio}')
#print(f'inference vs training vs torch err ratio max: {maxratio}')
...
@@ -2478,7 +2478,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
...
@@ -2478,7 +2478,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
assert
maxratio
<
1.005
and
maxratio
>
0.995
assert
maxratio
<
1.005
and
maxratio
>
0.995
elif
dtype
==
torch
.
bfloat16
:
elif
dtype
==
torch
.
bfloat16
:
if
dim
<=
512
:
if
dim
<=
512
:
assert
err1
<
5
e-4
assert
err1
<
6
e-4
assert
relerr1
<
0.007
assert
relerr1
<
0.007
assert
maxerr1
<
0.015
assert
maxerr1
<
0.015
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment