OpenDAS / bitsandbytes

Commit eefbf602, authored Jul 08, 2023 by Tim Dettmers
Parent: 7e49b5b9

Tuning optimization (float accumulation). 185 vs 50.
Showing 2 changed files with 22 additions and 38 deletions:

  csrc/kernels.cu            +17  -33
  tests/test_functional.py    +5   -5
csrc/kernels.cu
@@ -3528,29 +3528,26 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
   // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
   // 4 warps -> 4 loads per iter
   // 1x128 * 128x4 -> 1x4 outputs
-  typedef cub::WarpReduce<T> WarpReduce;
+  //typedef cub::WarpReduce<T> WarpReduce;
+  typedef cub::WarpReduce<float> WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
   const int warp_idx = threadIdx.x / 32;
   const int warp_lane = threadIdx.x % 32;
   const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
   const int num_values_8bit = num_values_4bit/2;
-  T local_C = T(0);
-  T lane_quant_value = nf4_data[warp_lane % 16];
+  //T local_C = T(0.0f);
+  float local_C = 0.0f;
   unsigned char local_B_4bit[num_values_8bit];
   T local_B[num_values_4bit];
   T local_A[num_values_4bit];
-  __shared__ T quant_map[16*THREADS];
-  __shared__ T quant_map2[16];
+  __shared__ T quant_map[16];
+  T local_absmax = T(0.0f);
-  //for(int i = 0; i < 16; i++)
-  //  quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
-  //__syncthreads();
-  for(int i = 0; i < 16; i++)
-    quant_map2[i] = nf4_data[i];
+  for(int i = threadIdx.x; i < 16; i++)
+    quant_map[i] = nf4_data[i];
+  __syncthreads();
   // A: [1, K]
   // B: [N, K]
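What changed here: the accumulator `local_C` and the `cub::WarpReduce` instantiation move from the template type `T` (typically `half`) to `float`, so per-thread products are summed at full precision and cast back to `T` only at the final store. A minimal sketch of the difference, assuming fp16 inputs; the kernel `dot_both` and the host driver are hypothetical examples, not part of this commit:

    #include <cuda_fp16.h>
    #include <cstdio>

    // Hedged sketch (not part of the commit): accumulate a dot product of
    // __half inputs either in __half or in float. With many small terms the
    // __half accumulator loses low-order bits; float keeps them and is cast
    // back to __half only once, at the end -- the same idea as local_C above.
    __global__ void dot_both(const __half* a, const __half* b, int n,
                             __half* out_half_acc, __half* out_float_acc)
    {
        __half hacc = __float2half(0.0f);
        float  facc = 0.0f;
        for (int i = 0; i < n; i++)
        {
            hacc = __hadd(hacc, __hmul(a[i], b[i]));         // fp16 accumulation
            facc += __half2float(a[i]) * __half2float(b[i]); // fp32 accumulation
        }
        *out_half_acc  = hacc;
        *out_float_acc = __float2half(facc);                 // single final cast
    }

    int main()
    {
        const int n = 4096;
        __half *a, *b, *o1, *o2;
        cudaMallocManaged(&a, n * sizeof(__half));
        cudaMallocManaged(&b, n * sizeof(__half));
        cudaMallocManaged(&o1, sizeof(__half));
        cudaMallocManaged(&o2, sizeof(__half));
        for (int i = 0; i < n; i++) { a[i] = __float2half(0.01f); b[i] = __float2half(1.0f); }

        dot_both<<<1, 1>>>(a, b, n, o1, o2);
        cudaDeviceSynchronize();
        // The fp16 accumulator drifts once the running sum gets large; the
        // float accumulator stays close to the exact value 0.01 * n.
        printf("half acc: %f  float acc: %f  exact: %f\n",
               __half2float(*o1), __half2float(*o2), 0.01f * n);
        return 0;
    }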
@@ -3559,7 +3556,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
     int inner_idx_halved = inner_idx/2;
     int offset_B = ldb*row_B;
     int absidx = ((2*offset_B) + inner_idx)/blocksize;
-    T local_absmax = __ldg(&(absmax[absidx]));
+    local_absmax = __ldg(&(absmax[absidx]));
     if(row_B < M)
     {
@@ -3576,25 +3573,11 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
...
@@ -3576,25 +3573,11 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
}
}
}
}
if
(
inner_idx
+
(
num_values_4bit
*
32
)
<
K
)
#pragma unroll
{
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
// full warp is running
#pragma unroll
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
{
local_B
[
k
*
2
]
=
__shfl_sync
(
0xffffffff
,
lane_quant_value
,
local_B_4bit
[
k
]
>>
4
)
*
local_absmax
;
local_B
[
k
*
2
+
1
]
=
__shfl_sync
(
0xffffffff
,
lane_quant_value
,
local_B_4bit
[
k
]
&
0x0F
)
*
local_absmax
;
}
}
else
{
{
// part of the warp exited already
local_B
[
k
*
2
]
=
quant_map
[
local_B_4bit
[
k
]
>>
4
]
*
local_absmax
;
#pragma unroll
local_B
[
k
*
2
+
1
]
=
quant_map
[
local_B_4bit
[
k
]
&
0x0F
]
*
local_absmax
;
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
{
local_B
[
k
*
2
]
=
quant_map2
[(
local_B_4bit
[
k
]
>>
4
)]
*
local_absmax
;
local_B
[
k
*
2
+
1
]
=
quant_map2
[(
local_B_4bit
[
k
]
&
0x0F
)]
*
local_absmax
;
}
}
}
if
(
inner_idx
+
num_values_4bit
)
if
(
inner_idx
+
num_values_4bit
)
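The old code picked between a `__shfl_sync`-based lookup (broadcasting `lane_quant_value` across the warp) and a shared-memory table, depending on whether the whole warp was still active; the new code always dequantizes through the 16-entry shared table `quant_map`. Each byte of `local_B_4bit` packs two 4-bit codes, one per nibble, and each decoded value is scaled by the block's `absmax`. A hedged host-side sketch of that unpack-and-scale step, using an evenly spaced placeholder table (not the actual NF4 constants):

    #include <cstdio>

    int main()
    {
        // Placeholder 16-entry dequantization table (NOT the real NF4 values).
        float quant_map[16];
        for (int i = 0; i < 16; i++)
            quant_map[i] = (i - 7.5f) / 7.5f;  // evenly spaced stand-in codes

        unsigned char packed[2] = { 0x3A, 0xF0 };  // two bytes = four 4-bit codes
        float absmax = 2.5f;                       // per-block scale factor
        float out[4];

        for (int k = 0; k < 2; k++)
        {
            out[k*2]     = quant_map[packed[k] >> 4]   * absmax;  // high nibble
            out[k*2 + 1] = quant_map[packed[k] & 0x0F] * absmax;  // low nibble
        }

        for (int k = 0; k < 4; k++)
            printf("out[%d] = %f\n", k, out[k]);
        return 0;
    }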
@@ -3603,6 +3586,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
         reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
         reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
         reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
       }
       else
         for(int k = 0; k < num_values_4bit; k++)

@@ -3610,14 +3594,14 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
       #pragma unroll
       for(int k = 0; k < num_values_4bit; k++)
-        local_C += local_A[k]*local_B[k];
+        local_C += (float)(local_A[k]*local_B[k]);
     }
   }
   local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
   if(row_B < M && warp_lane == 0)
-    out[row_B] = local_C;
+    out[row_B] = T(local_C);
 }
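After the per-thread loop, each warp reduces its 32 partial sums with `cub::WarpReduce` and lane 0 writes the row result, now cast back to `T` only at the store. A self-contained sketch of that reduction pattern under the same layout assumptions (4 warps per block, one warp per output row); the kernel name `warp_sum_rows` is hypothetical:

    #include <cub/cub.cuh>
    #include <cstdio>

    // Hedged sketch (not the library kernel): one warp per output row, float
    // partial sums, cub::WarpReduce for the final sum, lane 0 stores the
    // result -- mirroring the temp_storage[warp_idx]/warp_lane pattern above.
    template <int THREADS>
    __global__ void warp_sum_rows(const float* in, float* out, int rows, int cols)
    {
        typedef cub::WarpReduce<float> WarpReduce;
        __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];

        const int warp_idx  = threadIdx.x / 32;
        const int warp_lane = threadIdx.x % 32;
        const int row       = (THREADS/32)*blockIdx.x + warp_idx;
        if (row >= rows) return;  // row is per-warp, so whole warps exit together

        float local = 0.0f;
        for (int c = warp_lane; c < cols; c += 32)  // strided per-lane partials
            local += in[row*cols + c];

        float total = WarpReduce(temp_storage[warp_idx]).Sum(local);
        if (warp_lane == 0)
            out[row] = total;
    }

    int main()
    {
        const int rows = 8, cols = 128;
        float *in, *out;
        cudaMallocManaged(&in, rows * cols * sizeof(float));
        cudaMallocManaged(&out, rows * sizeof(float));
        for (int i = 0; i < rows * cols; i++) in[i] = 1.0f;

        warp_sum_rows<128><<<(rows + 3) / 4, 128>>>(in, out, rows, cols);
        cudaDeviceSynchronize();
        printf("row 0 sum = %f (expect %d)\n", out[0], cols);  // each row sums to cols
        return 0;
    }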
tests/test_functional.py
@@ -2420,7 +2420,7 @@ def test_cutlass3_gemm(dtype):
 def test_gemm_4bit(dtype):
     print('')
     #for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
-    for dim in [4096]:
+    for dim in [4*1024]:
         errs = []
         relerrs = []
         max_err = 0

@@ -2485,10 +2485,10 @@ def test_gemm_4bit(dtype):
         #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
         #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
         #print(dim, (max_err.item(), max_relerr.item()))
-        #print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
-        #print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
-        #assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
-        #assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
+        print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
+        print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
+        assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
+        assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15

 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
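The re-enabled assertions bound the mean absolute and mean relative error of the 4-bit GEMM, normalized by sqrt(dim); for roughly independent per-element rounding noise, the error of a length-dim dot product grows on the order of sqrt(dim), so the normalized metric stays comparable across sizes. The same check as a hedged C++ sketch, with made-up error values standing in for the test's `errs`:

    #include <cassert>
    #include <cmath>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main()
    {
        const int dim = 4 * 1024;
        // Made-up per-trial absolute errors (placeholders for the test's errs).
        std::vector<double> errs = {0.31, 0.28, 0.35, 0.30};

        // Same normalization as the test: mean error scaled by 1/sqrt(dim).
        double mean = std::accumulate(errs.begin(), errs.end(), 0.0) / errs.size();
        double normalized = mean / std::sqrt((double)dim);
        printf("normalized error = %f (threshold 0.011)\n", normalized);
        assert(normalized < 0.011);
        return 0;
    }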