Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
dfe6900b
Commit
dfe6900b
authored
Jul 04, 2023
by
Tim Dettmers
Browse files
Vectorized loads, conflict free NF4; 52 vs 172.
parent
f89ff93e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
44 additions
and
46 deletions
+44
-46
csrc/kernels.cu
csrc/kernels.cu
+38
-41
csrc/ops.cu
csrc/ops.cu
+1
-0
tests/test_functional.py
tests/test_functional.py
+5
-5
No files found.
csrc/kernels.cu
View file @
dfe6900b
...
@@ -3519,7 +3519,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
...
@@ -3519,7 +3519,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
out
[
col_offset
+
warp_lane
]
=
smem_C
[
warp_lane
];
out
[
col_offset
+
warp_lane
]
=
smem_C
[
warp_lane
];
}
}
#define num_values_4bit
16
#define num_values_4bit
32
template
<
typename
T
,
int
THREADS
>
__global__
void
kgemm_4bit_inference_naive
(
int
M
,
int
N
,
int
K
,
T
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
T
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
template
<
typename
T
,
int
THREADS
>
__global__
void
kgemm_4bit_inference_naive
(
int
M
,
int
N
,
int
K
,
T
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
T
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
)
{
{
...
@@ -3529,72 +3529,68 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
...
@@ -3529,72 +3529,68 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
// 4 warps -> 4 loads per iter
// 4 warps -> 4 loads per iter
// 1x128 * 128x4 -> 1x4 outputs
// 1x128 * 128x4 -> 1x4 outputs
typedef
cub
::
WarpReduce
<
T
>
WarpReduce
;
typedef
cub
::
WarpReduce
<
T
>
WarpReduce
;
__shared__
typename
WarpReduce
::
TempStorage
temp_storage
[
4
];
__shared__
typename
WarpReduce
::
TempStorage
temp_storage
[
THREADS
/
32
];
const
int
warp_idx
=
threadIdx
.
x
/
32
;
const
int
warp_idx
=
threadIdx
.
x
/
32
;
const
int
warp_lane
=
threadIdx
.
x
%
32
;
const
int
warp_lane
=
threadIdx
.
x
%
32
;
const
int
row_B
=
4
*
blockIdx
.
x
+
warp_idx
;
const
int
row_B
=
(
THREADS
/
32
)
*
blockIdx
.
x
+
warp_idx
;
const
int
num_values_8bit
=
num_values_4bit
/
2
;
T
local_C
=
T
(
0
);
T
local_C
=
T
(
0
);
T
quant_map
[
16
];
unsigned
char
local_B_4bit
[
num_values_8bit
];
#pragma unroll 16
for
(
int
i
=
0
;
i
<
16
;
i
++
)
quant_map
[
i
]
=
nf4_data
[
i
];
unsigned
char
local_B_4bit
[
num_values_4bit
/
2
];
T
local_B
[
num_values_4bit
];
T
local_B
[
num_values_4bit
];
T
local_A
[
num_values_4bit
];
__shared__
half
quant_map
[
16
*
THREADS
];
// need to increase occupancy by splitting the rows, but can be done later
for
(
int
i
=
0
;
i
<
16
;
i
++
)
quant_map
[
threadIdx
.
x
+
(
i
*
blockDim
.
x
)]
=
nf4_data
[
i
];
__syncthreads
();
// A: [1, K]
// A: [1, K]
// B: [N, K]
// B: [N, K]
for
(
int
inner_idx
=
warp_lane
*
num_values_4bit
;
inner_idx
<
K
;
inner_idx
+=
32
*
num_values_4bit
)
for
(
int
inner_idx
=
warp_lane
*
num_values_4bit
;
inner_idx
<
K
;
inner_idx
+=
32
*
num_values_4bit
)
{
{
int
offset_B
=
ldb
*
row_B
+
(
inner_idx
/
2
);
int
inner_idx_halved
=
inner_idx
/
2
;
int
absidx
=
(
2
*
offset_B
)
/
blocksize
;
int
offset_B
=
ldb
*
row_B
;
int
absidx
=
((
2
*
offset_B
)
+
inner_idx
)
/
blocksize
;
T
local_absmax
=
__ldg
(
&
(
absmax
[
absidx
]));
T
local_absmax
=
__ldg
(
&
(
absmax
[
absidx
]));
//printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B);
if
(
row_B
<
M
)
#pragma unroll
for
(
int
k
=
0
;
k
<
num_values_4bit
/
2
;
k
++
)
{
{
if
((
inner_idx
/
2
)
<
K
&&
row_B
<
M
)
if
((
inner_idx_halved
+
num_values_8bit
)
<
K
)
local_B_4bit
[
k
]
=
B
[
offset_B
+
k
];
{
reinterpret_cast
<
int4
(
&
)[
num_values_8bit
]
>
(
local_B_4bit
)[
0
]
=
reinterpret_cast
<
int4
*>
(
B
)[(
offset_B
+
(
inner_idx_halved
))
/
(
num_values_8bit
)];
}
else
else
local_B_4bit
[
k
]
=
0b01110111
;
{
#pragma unroll
for
(
int
j
=
0
;
j
<
(
num_values_8bit
);
j
++
)
if
((
inner_idx
/
2
)
+
j
<
K
)
local_B_4bit
[
j
]
=
0b01110111
;
}
}
}
//if(row_B < M)
//{
// if((inner_idx/num_values_4bit) < K)
// reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[offset_B/(num_values_4bit/2)];
// else
// {
// for(int k = 0; k < num_values_4bit/2; k++)
// {
// if((inner_idx/2) < K && row_B < M)
// local_B_4bit[k] = B[offset_B + k];
// else
// local_B_4bit[k] = 0b01110111;
// }
// }
//}
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
{
{
local_B
[
k
*
2
]
=
quant_map
[
local_B_4bit
[
k
]
>>
4
]
*
local_absmax
;
local_B
[
k
*
2
]
=
quant_map
[
(
local_B_4bit
[
k
]
>>
4
)
*
THREADS
+
threadIdx
.
x
]
*
local_absmax
;
local_B
[
k
*
2
+
1
]
=
quant_map
[
local_B_4bit
[
k
]
&
0x0F
]
*
local_absmax
;
local_B
[
k
*
2
+
1
]
=
quant_map
[
(
local_B_4bit
[
k
]
&
0x0F
)
*
THREADS
+
threadIdx
.
x
]
*
local_absmax
;
}
}
//printnonzero<T>(local_B, 4, "B values: ");
if
(
inner_idx
+
num_values_4bit
)
{
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
0
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_8bit
/
2
)
+
0
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
1
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_8bit
/
2
)
+
1
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
2
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_8bit
/
2
)
+
2
];
reinterpret_cast
<
int4
(
&
)[
num_values_4bit
]
>
(
local_A
)[
3
]
=
reinterpret_cast
<
int4
*>
(
A
)[
inner_idx
/
(
num_values_8bit
/
2
)
+
3
];
}
else
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
local_A
[
k
]
=
A
[
inner_idx
+
k
];
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
for
(
int
k
=
0
;
k
<
num_values_4bit
;
k
++
)
local_C
+=
A
[
inner_idx
+
k
]
*
local_B
[
k
];
local_C
+=
local_A
[
k
]
*
local_B
[
k
];
}
}
...
@@ -3773,6 +3769,7 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
...
@@ -3773,6 +3769,7 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
template
__global__
void
kgemm_4bit_inference
<
half
,
160
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference
<
half
,
160
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference
<
half
,
256
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference
<
half
,
256
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
32
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
96
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
96
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
128
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
128
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
160
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
template
__global__
void
kgemm_4bit_inference_naive
<
half
,
160
>(
int
M
,
int
N
,
int
K
,
half
*
__restrict__
const
A
,
unsigned
char
*
B
,
float
*
absmax
,
half
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
blocksize
);
...
...
csrc/ops.cu
View file @
dfe6900b
...
@@ -733,6 +733,7 @@ template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A,
...
@@ -733,6 +733,7 @@ template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A,
{
{
int
num_blocks
=
(
m
+
3
)
/
4
;
int
num_blocks
=
(
m
+
3
)
/
4
;
//int num_blocks = m;
cout
<<
num_blocks
<<
endl
;
cout
<<
num_blocks
<<
endl
;
//cout << lda << endl;
//cout << lda << endl;
...
...
tests/test_functional.py
View file @
dfe6900b
...
@@ -2415,21 +2415,21 @@ def test_gemm_4bit(dtype):
...
@@ -2415,21 +2415,21 @@ def test_gemm_4bit(dtype):
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
#for dim in [4096, 5120, 6656, 8192]:
#for dim in [4096, 5120, 6656, 8192]:
#for dim in [32]:
#for dim in [32]:
for
dim
in
[
4096
]:
for
dim
in
[
2
*
4096
]:
#for dim in [5120]:
#for dim in [5120]:
#for dim in [6656]:
#for dim in [6656]:
#for dim in [
128
]:
#for dim in [
4
]:
errs
=
[]
errs
=
[]
relerrs
=
[]
relerrs
=
[]
max_err
=
0
max_err
=
0
max_relerr
=
0
max_relerr
=
0
for
i
in
range
(
1
):
for
i
in
range
(
1
00
):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
A
=
torch
.
randn
(
1
,
dim
+
2
,
dtype
=
dtype
,
device
=
'cuda'
)
A
=
torch
.
randn
(
1
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
B
=
torch
.
randn
(
4
*
dim
,
dim
+
2
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
dim
)
B
=
torch
.
randn
(
4
*
dim
,
dim
,
dtype
=
dtype
,
device
=
'cuda'
)
/
math
.
sqrt
(
dim
)
#B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
#B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
#print('')
#print('')
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment