OpenDAS / bitsandbytes · Commits

Commit eefbf602, authored Jul 08, 2023 by Tim Dettmers

    Tuning optimization (float accumulation). 185 vs 50.

parent 7e49b5b9
Showing 2 changed files with 22 additions and 38 deletions:

    csrc/kernels.cu           +17  -33
    tests/test_functional.py   +5   -5
csrc/kernels.cu
@@ -3528,29 +3528,26 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
   // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
   // 4 warps -> 4 loads per iter
   // 1x128 * 128x4 -> 1x4 outputs
-  typedef cub::WarpReduce<T> WarpReduce;
+  //typedef cub::WarpReduce<T> WarpReduce;
+  typedef cub::WarpReduce<float> WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];

   const int warp_idx = threadIdx.x / 32;
   const int warp_lane = threadIdx.x % 32;
   const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
   const int num_values_8bit = num_values_4bit/2;
-  T local_C = T(0);
-  T lane_quant_value = nf4_data[warp_lane % 16];
+  //T local_C = T(0.0f);
+  float local_C = 0.0f;

   unsigned char local_B_4bit[num_values_8bit];
   T local_B[num_values_4bit];
   T local_A[num_values_4bit];
-  __shared__ T quant_map[16*THREADS];
-  __shared__ T quant_map2[16];
-  //for(int i = 0; i < 16; i++)
-  //  quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
-  //__syncthreads();
+  __shared__ T quant_map[16];
+  T local_absmax = T(0.0f);

-  for(int i = 0; i < 16; i++)
-    quant_map2[i] = nf4_data[i];
+  for(int i = threadIdx.x; i < 16; i++)
+    quant_map[i] = nf4_data[i];
   __syncthreads();

   // A: [1, K]
   // B: [N, K]
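The hunk above switches both the accumulator and the warp reduction from the template element type T (half or bfloat16 in this kernel) to float, so partial sums no longer round to 16-bit precision at every addition. A minimal standalone sketch of the same pattern, not the library's kernel; warp_dot, a, b, out, and n are illustrative names:

// Sketch only: half inputs, float accumulation, warp-level sum via
// cub::WarpReduce<float>. Requires sm_53+ for device-side half math.
// Launch with a single 32-thread warp, e.g. warp_dot<<<1, 32>>>(a, b, out, n).
#include <cub/cub.cuh>
#include <cuda_fp16.h>

__global__ void warp_dot(const __half *a, const __half *b, float *out, int n)
{
  typedef cub::WarpReduce<float> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_storage;

  // Accumulate in float so partial sums never round to 16 bits.
  float local_C = 0.0f;
  for (int i = threadIdx.x; i < n; i += 32)
    local_C += (float)(a[i] * b[i]);   // half multiply, float accumulate

  // Warp-level reduction in float; lane 0 receives the sum.
  local_C = WarpReduce(temp_storage).Sum(local_C);
  if (threadIdx.x == 0)
    *out = local_C;
}

With T = half, summing thousands of products directly in half loses low-order bits at each step; a float accumulator defers all rounding to the final store.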
@@ -3559,7 +3556,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
     int inner_idx_halved = inner_idx/2;
     int offset_B = ldb*row_B;
     int absidx = ((2*offset_B) + inner_idx)/blocksize;
-    T local_absmax = __ldg(&(absmax[absidx]));
+    local_absmax = __ldg(&(absmax[absidx]));

     if(row_B < M)
     {
@@ -3576,25 +3573,11 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
       }
     }

-    if(inner_idx + (num_values_4bit*32) < K)
-    {
-      // full warp is running
-      #pragma unroll
-      for(int k = 0; k < num_values_4bit; k++)
-      {
-        local_B[k*2] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] >> 4)*local_absmax;
-        local_B[k*2 + 1] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] & 0x0F)*local_absmax;
-      }
-    }
-    else
-    {
-      // part of the warp exited already
-      #pragma unroll
-      for(int k = 0; k < num_values_4bit; k++)
-      {
-        local_B[k*2] = quant_map2[(local_B_4bit[k] >> 4)]*local_absmax;
-        local_B[k*2 + 1] = quant_map2[(local_B_4bit[k] & 0x0F)]*local_absmax;
-      }
-    }
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit; k++)
+    {
+      local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
+      local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+    }

     if(inner_idx + num_values_4bit)
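The replacement collapses the two dequantization paths (warp shuffle vs. quant_map2) into one lookup through the 16-entry shared-memory table quant_map: each packed byte holds two 4-bit NF4 codes, the high nibble selected with >> 4 and the low nibble with & 0x0F, each scaled by the block's absmax. A minimal sketch of that unpacking, not the kernel itself; dequant_nf4, packed, lut, scale, out, and n_bytes are illustrative names:

// Sketch: unpack two 4-bit codes per byte through a 16-entry table.
__device__ void dequant_nf4(const unsigned char *packed, int n_bytes,
                            const float *lut,   // 16 NF4 code values
                            float scale,        // blockwise absmax
                            float *out)         // 2*n_bytes outputs
{
  for (int k = 0; k < n_bytes; k++)
  {
    out[k*2]     = lut[packed[k] >> 4]   * scale;  // high nibble
    out[k*2 + 1] = lut[packed[k] & 0x0F] * scale;  // low nibble
  }
}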
@@ -3603,6 +3586,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
         reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
         reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
         reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
       }
       else
         for(int k = 0; k < num_values_4bit; k++)
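The context lines here (unchanged by this commit) fill local_A through int4-typed casts, i.e. 16-byte vectorized loads that fetch eight half values per memory transaction instead of one. A minimal sketch of the trick, assuming src and dst are 16-byte aligned; load_vec8 and its parameters are illustrative names:

// Sketch: one int4 load/store moves 16 bytes = eight half values.
#include <cuda_fp16.h>

__device__ void load_vec8(__half *dst, const __half *src)
{
  *reinterpret_cast<int4*>(dst) = *reinterpret_cast<const int4*>(src);
}

The kernel's reinterpret_cast<int4(&)[...]> form does the same thing for a local array, letting it index the array in int4-sized chunks.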
@@ -3610,14 +3594,14 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
       #pragma unroll
       for(int k = 0; k < num_values_4bit; k++)
-        local_C += local_A[k]*local_B[k];
+        local_C += (float)(local_A[k]*local_B[k]);
     }

     local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);

     if(row_B < M && warp_lane == 0)
-      out[row_B] = local_C;
+      out[row_B] = T(local_C);
 }
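Two details of this final hunk are easy to miss. Each product local_A[k]*local_B[k] is still computed in T and only then widened, so (float)(a*b) rounds the product to 16 bits before accumulating, unlike (float)a*(float)b; and the float result is explicitly narrowed back with T(local_C) on store, since out has element type T. A small sketch contrasting the two cast placements; mul_then_cast and cast_then_mul are illustrative names:

#include <cuda_fp16.h>

// As in the commit: multiply in half (product rounded to 16 bits),
// then widen to float for accumulation.
__device__ float mul_then_cast(__half a, __half b)
{
  return (float)(a * b);
}

// Alternative: widen first, multiply in float (no intermediate rounding).
__device__ float cast_then_mul(__half a, __half b)
{
  return (float)a * (float)b;
}

The commit keeps the half multiply, which maps directly onto half FMA hardware, and moves only the accumulation and warp reduction to float.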
tests/test_functional.py
@@ -2420,7 +2420,7 @@ def test_cutlass3_gemm(dtype):
 def test_gemm_4bit(dtype):
     print('')
     #for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
-    for dim in [4096]:
+    for dim in [4*1024]:
         errs = []
         relerrs = []
         max_err = 0
@@ -2485,10 +2485,10 @@ def test_gemm_4bit(dtype):
         #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
         #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
         #print(dim, (max_err.item(), max_relerr.item()))
-        #print(sum(errs)/len(errs)/math.sqrt(dim), 0.00015)
-        #print(sum(relerrs)/len(relerrs)/math.sqrt(dim), 0.0015)
-        #assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
-        #assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
+        print(sum(errs)/len(errs)/math.sqrt(dim), 0.00015)
+        print(sum(relerrs)/len(relerrs)/math.sqrt(dim), 0.0015)
+        assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
+        assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15

 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
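A plausible reading of the re-enabled assertions, not stated in the commit itself: for a length-dim dot product with roughly independent per-element quantization errors of spread sigma, the total error grows on the order of sigma * sqrt(dim), so dividing the mean absolute error by math.sqrt(dim) gives a dimension-independent statistic that can be checked against fixed thresholds (0.011 absolute, 0.15 relative). The values 0.00015 and 0.0015 printed alongside look like typical observed magnitudes for comparison rather than bounds.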