OpenDAS / bitsandbytes / Commits

Commit dfe6900b, authored Jul 04, 2023 by Tim Dettmers
parent f89ff93e

    Vectorized loads, conflict free NF4; 52 vs 172.

Showing 3 changed files with 44 additions and 46 deletions:

    csrc/kernels.cu           +38  -41
    csrc/ops.cu                +1   -0
    tests/test_functional.py   +5   -5
csrc/kernels.cu
...
@@ -3519,7 +3519,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
       out[col_offset + warp_lane] = smem_C[warp_lane];
 }
-#define num_values_4bit 16
+#define num_values_4bit 32
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T *out, int lda, int ldb, int ldc, int blocksize)
 {
...
@@ -3529,72 +3529,68 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
   // 4 warps -> 4 loads per iter
   // 1x128 * 128x4 -> 1x4 outputs
   typedef cub::WarpReduce<T> WarpReduce;
-  __shared__ typename WarpReduce::TempStorage temp_storage[4];
+  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
   const int warp_idx = threadIdx.x / 32;
   const int warp_lane = threadIdx.x % 32;
-  const int row_B = 4*blockIdx.x + warp_idx;
+  const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
+  const int num_values_8bit = num_values_4bit/2;
   T local_C = T(0);
-  T quant_map[16];
-  #pragma unroll 16
-  for(int i = 0; i < 16; i++)
-    quant_map[i] = nf4_data[i];
-  unsigned char local_B_4bit[num_values_4bit/2];
+  unsigned char local_B_4bit[num_values_8bit];
   T local_B[num_values_4bit];
+  T local_A[num_values_4bit];
+  __shared__ half quant_map[16*THREADS];
+  // need to increase occupancy by splitting the rows, but can be done later
+  for(int i = 0; i < 16; i++)
+    quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
+  __syncthreads();
   // A: [1, K]
   // B: [N, K]
   for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
   {
-    int offset_B = ldb*row_B + (inner_idx/2);
-    int absidx = (2*offset_B)/blocksize;
+    int inner_idx_halved = inner_idx/2;
+    int offset_B = ldb*row_B;
+    int absidx = ((2*offset_B) + inner_idx)/blocksize;
     T local_absmax = __ldg(&(absmax[absidx]));
     //printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B);
-    #pragma unroll
-    for(int k = 0; k < num_values_4bit/2; k++)
-      if((inner_idx/2) < K && row_B < M)
-        local_B_4bit[k] = B[offset_B + k];
-      else
-        local_B_4bit[k] = 0b01110111;
+    if(row_B < M)
+    {
+      if((inner_idx_halved + num_values_8bit) < K)
+      {
+        reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B + (inner_idx_halved))/(num_values_8bit)];
+      }
+      else
+      {
+        #pragma unroll
+        for(int j = 0; j < (num_values_8bit); j++)
+          if((inner_idx/2) + j < K)
+            local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
+          else
+            local_B_4bit[j] = 0b01110111;
+      }
+    }
+    //if(row_B < M)
+    //{
+    //  if((inner_idx/num_values_4bit) < K)
+    //    reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[offset_B/(num_values_4bit/2)];
+    //  else
+    //  {
+    //    for(int k = 0; k < num_values_4bit/2; k++)
+    //    {
+    //      if((inner_idx/2) < K && row_B < M)
+    //        local_B_4bit[k] = B[offset_B + k];
+    //      else
+    //        local_B_4bit[k] = 0b01110111;
+    //    }
+    //  }
+    //}
     #pragma unroll
     for(int k = 0; k < num_values_4bit/2; k++)
     {
-      local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
-      local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+      local_B[k*2] = quant_map[(local_B_4bit[k] >> 4)*THREADS + threadIdx.x]*local_absmax;
+      local_B[k*2 + 1] = quant_map[(local_B_4bit[k] & 0x0F)*THREADS + threadIdx.x]*local_absmax;
     }
     //printnonzero<T>(local_B, 4, "B values: ");
+    if(inner_idx + num_values_4bit)
+    {
+      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 0];
+      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
+      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
+      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
+    }
+    else
+      for(int k = 0; k < num_values_4bit; k++)
+        local_A[k] = A[inner_idx + k];
     #pragma unroll
     for(int k = 0; k < num_values_4bit; k++)
-      local_C += A[inner_idx + k]*local_B[k];
+      local_C += local_A[k]*local_B[k];
   }
...
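The "conflict free NF4" half of the commit message is the quant_map change in the hunk above: the old kernel kept the 16 NF4 codebook values in a per-thread array, which dynamic indexing forces out of registers into local memory, while the new kernel replicates the table across shared memory with stride THREADS so that thread t always reads its own copy at offset t. A minimal standalone sketch of the layout (kernel name and table values are hypothetical, and a float table is used so each entry occupies exactly one 4-byte bank slot):

    // Sketch: replicated, bank-conflict-free shared-memory lookup table.
    // Thread t reads lut[v*NTHREADS + t]. With NTHREADS a multiple of 32,
    // (v*NTHREADS + t) % 32 == t % 32, so each thread stays in its own bank
    // no matter which code v it looks up, and warp lookups never serialize.
    #define NTHREADS 128  // assumed block size, multiple of 32

    __global__ void lut_demo(const unsigned char *codes, float *out, int n)
    {
      __shared__ float lut[16 * NTHREADS];

      for(int v = 0; v < 16; v++)
        lut[threadIdx.x + v*blockDim.x] = v / 15.0f;  // placeholder codebook
      __syncthreads();

      int i = blockIdx.x*blockDim.x + threadIdx.x;
      if(i < n)
        out[i] = lut[(codes[i] & 0x0F)*NTHREADS + threadIdx.x];
    }

The trade-off is 16*THREADS table entries of shared memory instead of 16, which is what the "need to increase occupancy" comment in the hunk is weighing against the conflict-free access.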
@@ -3773,6 +3769,7 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference_naive<half, 32>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference_naive<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize);
...
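The "vectorized loads" half is the reinterpret_cast<int4> pattern in the second hunk: a single 16-byte transaction fills all num_values_8bit quantized bytes of local_B_4bit, and four more fill the 32 half values of local_A, replacing per-byte and per-half loads. A self-contained sketch of the pattern (buffer and kernel names are hypothetical, not the library's API):

    // Sketch of int4-vectorized staging into per-thread buffers (illustrative).
    #include <cuda_fp16.h>

    #define NVALS 32            // half values staged per thread
    #define NBYTES (NVALS / 2)  // one byte packs two 4-bit codes

    __global__ void vec_load_demo(const half *A, const unsigned char *B, half *out)
    {
      half local_A[NVALS];
      unsigned char local_B[NBYTES];
      int base = threadIdx.x * NVALS;  // contiguous chunk per thread

      // One int4 (16 bytes) fills all NBYTES quantized bytes at once;
      // the byte offset base/2 must be 16-byte aligned for this to be legal.
      reinterpret_cast<int4*>(local_B)[0] =
          reinterpret_cast<const int4*>(B)[(base/2) / 16];

      // Four int4 loads fill the 32 halves (8 halves per 16-byte int4).
      #pragma unroll
      for(int j = 0; j < NVALS/8; j++)
        reinterpret_cast<int4*>(local_A)[j] =
            reinterpret_cast<const int4*>(A)[base/8 + j];

      out[threadIdx.x] = local_A[local_B[0] & 0x0F];  // placeholder use of both
    }

The kernel in the diff applies the same pattern and keeps a bounds-checked scalar fallback for the tail, where a full 16-byte vector would read past K.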
csrc/ops.cu
...
@@ -733,6 +733,7 @@ template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A,
 {
   int num_blocks = (m+3)/4;
   //int num_blocks = m;
+  cout << num_blocks << endl;
   //cout << lda << endl;
...
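Here (m+3)/4 is round-up division: the naive kernel assigns one output row per warp, so a 128-thread block covers four rows and m rows need ceil(m/4) blocks. A sketch of the matching host-side launch (the stream and the surrounding variable names are assumptions, not this file's code):

    // Assumed launch matching num_blocks = (m+3)/4 with THREADS = 128.
    const int threads = 128;                  // 4 warps per block
    const int rows_per_block = threads / 32;  // one row per warp
    int num_blocks = (m + rows_per_block - 1) / rows_per_block;  // == (m+3)/4

    kgemm_4bit_inference_naive<half, 128><<<num_blocks, threads, 0, stream>>>(
        m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);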
tests/test_functional.py
...
@@ -2415,21 +2415,21 @@ def test_gemm_4bit(dtype):
   #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
   #for dim in [4096, 5120, 6656, 8192]:
   #for dim in [32]:
-  for dim in [4096]:
+  for dim in [2*4096]:
   #for dim in [5120]:
   #for dim in [6656]:
   #for dim in [128]:
   #for dim in [4]:
     errs = []
     relerrs = []
     max_err = 0
     max_relerr = 0
-    for i in range(1):
+    for i in range(100):
       #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
       #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
       #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
       #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-      A = torch.randn(1, dim+2, dtype=dtype, device='cuda')
-      B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
+      A = torch.randn(1, dim, dtype=dtype, device='cuda')
+      B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
       #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
       #print('')
...