OpenDAS / bitsandbytes

Commit f6df4aef, authored Apr 28, 2023 by Tim Dettmers
Parent: 3aef7834

    Added fp16 and thread/item template.

Showing 6 changed files with 53 additions and 35 deletions (+53, -35)
bitsandbytes/functional.py    +8  -3
csrc/kernels.cu               +20 -19
csrc/kernels.cuh              +1  -1
csrc/ops.cu                   +2  -1
csrc/pythonInterface.c        +5  -0
tests/test_functional.py      +17 -11
bitsandbytes/functional.py
@@ -1381,9 +1381,9 @@ def cutlass3_gemm(
     transposed_A=False,
     transposed_B=False,
 ):
-    sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.float32)
+    sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if out is None:
-        out = torch.zeros(size=sout, dtype=torch.float32, device=A.device)
+        out = torch.zeros(size=sout, dtype=A.dtype, device=A.device)
     sA = A.shape
     sB = B.shape
@@ -1464,7 +1464,12 @@ def cutlass3_gemm(
     lda = ct.c_int32(lda)
     ldb = ct.c_int32(ldb)
     ldc = ct.c_int32(ldc)

-    lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
+    if A.dtype == torch.float32:
+        lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
+    elif A.dtype == torch.float16:
+        lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
+    else:
+        raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

     return out
csrc/kernels.cu
@@ -2949,18 +2949,18 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 #define ROWS 2
-template <typename T> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T* out, int lda, int ldb, int ldc)
+template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T* out, int lda, int ldb, int ldc)
 {
   // 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
   // 1. Load dataB into register
   // 2. Dequantize B
   // 3. Fetch data from A and multiply

-  typedef cub::BlockLoad<T, 256, 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
+  typedef cub::BlockLoad<T, THREADS, ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
   //__shared__ typename LoadA::TempStorage loada;
-  typedef cub::BlockLoad<T, 256, 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
+  typedef cub::BlockLoad<T, THREADS, ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
   //__shared__ typename LoadB::TempStorage loadb;
-  typedef cub::BlockReduce<T, 256> BlockReduce;
+  typedef cub::BlockReduce<T, THREADS> BlockReduce;
   // Allocate shared memory for BlockReduce
   //__shared__ typename BlockReduce::TempStorage reduce;
@@ -2971,15 +2971,13 @@ template <typename T> __global__ void gemm_device(int M, int N, int K, T const*
   } temp_storage;

-  T dataA[4];
-  T local_B[4];
+  T dataA[ITEMS];
+  T local_B[ITEMS];
   T local_accC[ROWS];
   int valid_items = 0;
   const int warp_id = threadIdx.x/32;
   const int warp_lane = threadIdx.x % 32;
   const int col_offset = blockIdx.x * 8;

-  __shared__ T tileA[ROWS*1024];
+  __shared__ T tileA[ROWS*THREADS*ITEMS];
   __shared__ T accumulatorC[ROWS*8];

   //#pragma unroll 8
@@ -2991,17 +2989,17 @@ template <typename T> __global__ void gemm_device(int M, int N, int K, T const*
   __syncthreads();

-  for(int inner_idx = 0; inner_idx < K; inner_idx += 1024)
+  for(int inner_idx = 0; inner_idx < K; inner_idx += THREADS*ITEMS)
   {
-    valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx;
+    valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
     int baserow = 0;
     for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
     {
       LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);

-      #pragma unroll 4
-      for(int k = 0; k < 4; k++)
-        tileA[row*1024 + threadIdx.x + (k*blockDim.x)] = dataA[k];
+      #pragma unroll ITEMS
+      for(int k = 0; k < ITEMS; k++)
+        tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k];

       __syncthreads();
     }
@@ -3021,16 +3019,16 @@ template <typename T> __global__ void gemm_device(int M, int N, int K, T const*
         local_accC[k] = 0.0f;

       int base_idxB = ldb*colB;
-      valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx;
+      valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
       LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
       __syncthreads();

       for(int row = 0; row < ROWS && row < N; row++)
       {
-        #pragma unroll 4
-        for(int k = 0; k < 4; k++)
+        #pragma unroll ITEMS
+        for(int k = 0; k < ITEMS; k++)
         {
-          int idxA = row*1024 + threadIdx.x + (blockDim.x*k);
+          int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k);
           local_accC[row] += tileA[idxA]*local_B[k];
         }
@@ -3124,7 +3122,10 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 //                            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
 //                            TC      * out, CStride dC, CBlockLayout   , CThreadLayout tC,
 //                            half alpha, half beta);

-template __global__ void gemm_device<float>(int M, int N, int K, float const* A, float* B, float* out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<float, 4, 256>(int M, int N, int K, float const* A, float* B, float* out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 4, 256>(int M, int N, int K, half const* A, half* B, half* out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<float, 8, 256>(int M, int N, int K, float const* A, float* B, float* out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 8, 256>(int M, int N, int K, half const* A, half* B, half* out, int lda, int ldb, int ldc);

 //template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
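The substance of the kernel change: the hard-coded 256-thread, 4-item configuration (a 1024-element tile per block) becomes the template pair <ITEMS, THREADS>, so the tile shape is picked per instantiation (the 4x256 and 8x256 variants above). Below is a minimal, self-contained sketch of the same cub pattern — a toy block reduction, not the repo's GEMM — showing how THREADS/ITEMS size the cooperative load and how valid_items pads the tail tile, exactly as gemm_device does.

// Toy kernel (hypothetical, not from the repo) built from the same cub pieces
// as gemm_device: BlockLoad sized by <THREADS, ITEMS>, BlockReduce by THREADS.
#include <cub/cub.cuh>

template <typename T, int ITEMS, int THREADS>
__global__ void tile_sum(const T *in, T *out, int n)
{
  typedef cub::BlockLoad<T, THREADS, ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
  typedef cub::BlockReduce<T, THREADS> Reduce;
  __shared__ union {
    typename Load::TempStorage load;
    typename Reduce::TempStorage reduce;
  } temp_storage;

  T items[ITEMS];
  const int tile = THREADS * ITEMS;                        // elements per block
  const int base = blockIdx.x * tile;
  const int valid_items = n - base > tile ? tile : n - base; // tail guard
  Load(temp_storage.load).Load(in + base, items, valid_items, T(0.0f)); // pad OOB with 0
  __syncthreads();  // temp_storage is a union: sync before reusing it

  T thread_sum = T(0.0f);
  #pragma unroll
  for (int k = 0; k < ITEMS; k++)
    thread_sum += items[k];

  T block_sum = Reduce(temp_storage.reduce).Sum(thread_sum);
  if (threadIdx.x == 0)
    out[blockIdx.x] = block_sum;
}

// Instantiated and launched like the kernel in this commit, e.g.:
//   tile_sum<float, 8, 256><<<(n + 2047) / 2048, 256>>>(d_in, d_out, n);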
csrc/kernels.cuh
@@ -138,6 +138,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 template <size_t stages_count /* Pipeline with stages_count stages */>
 __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);

-template <typename T> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T* out, int lda, int ldb, int ldc);
+template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T* out, int lda, int ldb, int ldc);

 #endif
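A consequence of this header change: gemm_device is only declared here and defined in kernels.cu, so every <T, ITEMS, THREADS> combination launched from another translation unit needs a matching explicit instantiation in kernels.cu (the four new template lines above); a missing one surfaces as an undefined-symbol error at link time, not at compile time. A toy of the same declaration/instantiation split, with hypothetical names:

// toy.cuh (hypothetical) — declaration only; callers see just the signature.
#include <cuda_fp16.h>

template <typename T, int ITEMS, int THREADS>
__global__ void fill_strided(T *out, T value, int n);

// toy.cu (hypothetical) — definition plus explicit instantiations. Dropping
// either `template ...` line below leaves callers of that combination with an
// undefined-symbol error when the program is linked.
template <typename T, int ITEMS, int THREADS>
__global__ void fill_strided(T *out, T value, int n)
{
  // Same strided indexing as gemm_device's tileA fill: thread t writes
  // elements t, t+THREADS, ..., t+(ITEMS-1)*THREADS of its block's tile.
  const int base = blockIdx.x * THREADS * ITEMS + threadIdx.x;
  #pragma unroll
  for (int k = 0; k < ITEMS; k++)
    if (base + k * THREADS < n)
      out[base + k * THREADS] = value;
}

template __global__ void fill_strided<float, 8, 256>(float *out, float value, int n);
template __global__ void fill_strided<half, 8, 256>(half *out, half value, int n);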
csrc/ops.cu
@@ -689,7 +689,7 @@ template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T
   cout << m << endl;
   cout << n << endl;
   cout << k << endl;

-  gemm_device
+  gemm_device<T, 8, 256>
       <<<num_blocks, dimBlock, 0, 0>>>
       (m, n, k,
        A,
@@ -702,6 +702,7 @@ template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T
 //==============================================================

 template void gemm_host<float>(int m, int n, int k, float const* A, float* B, float* out, int lda, int ldb, int ldc);
+template void gemm_host<half>(int m, int n, int k, half const* A, half* B, half* out, int lda, int ldb, int ldc);

 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
 template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
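The launch site now fixes ITEMS=8 and THREADS=256 at compile time. The hunk does not show how num_blocks and dimBlock are computed; the sketch below is one plausible reading, inferred from the kernel's col_offset = blockIdx.x * 8 (one block per 8 output columns). Treat the grid math as an assumption, not the repo's code.

// Hedged sketch of the host-side launch. dimBlock must equal the THREADS
// template argument; the num_blocks formula is an assumption inferred from
// col_offset = blockIdx.x * 8 in gemm_device, not copied from gemm_host.
#include "kernels.cuh"  // for the gemm_device<T, ITEMS, THREADS> declaration

template <typename T>
void launch_gemm(int m, int n, int k, T const *A, T *B, T *out,
                 int lda, int ldb, int ldc)
{
  const int THREADS = 256;        // must match the <T, 8, 256> instantiation
  dim3 dimBlock(THREADS);
  int num_blocks = (n + 7) / 8;   // assumed: each block produces 8 output columns
  gemm_device<T, 8, 256><<<num_blocks, dimBlock, 0, 0>>>(m, n, k, A, B, out,
                                                         lda, ldb, ldc);
}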
csrc/pythonInterface.c
@@ -22,6 +22,8 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate
 void gemm_host_fp32(int M, int N, int K, float const* A, float* B, float* out, int lda, int ldb, int ldc)
 { gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc); }

+void gemm_host_fp16(int M, int N, int K, half const* A, half* B, half* out, int lda, int ldb, int ldc)
+{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc); }

 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
@@ -314,6 +316,9 @@ extern "C"
 	void cgemm_host_fp32(int M, int N, int K, float const* A, float* B, float* out, int lda, int ldb, int ldc)
 	{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }

+	void cgemm_host_fp16(int M, int N, int K, half const* A, half* B, half* out, int lda, int ldb, int ldc)
+	{ gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }
+
 #endif
 	void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
 	void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
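These wrappers exist because functional.py reaches the library through ctypes, which can only resolve unmangled C symbols; gemm_host<T> has a mangled C++ name, so each dtype gets a plain-C entry point. The same pattern in miniature (collapsing the repo's intermediate gemm_host_fpXX layer into one step):

// Minimal sketch of the ctypes-export pattern; the real file adds an
// intermediate gemm_host_fpXX layer, as shown in the hunks above.
#include <cuda_fp16.h>

// C++ template entry point: mangled symbol name, invisible to ctypes.
template <typename T> void gemm_host(int m, int n, int k, T const *A, T *B,
                                     T *out, int lda, int ldb, int ldc);

// extern "C" shims: stable, unmangled names ctypes can look up, one per dtype.
extern "C" void cgemm_host_fp32(int M, int N, int K, float const *A, float *B,
                                float *out, int lda, int ldb, int ldc)
{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc); }

extern "C" void cgemm_host_fp16(int M, int N, int K, half const *A, half *B,
                                half *out, int lda, int ldb, int ldc)
{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc); }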
tests/test_functional.py
@@ -2352,20 +2352,26 @@ def test_normal_map_tree():
     print(pivots)

-def test_cutlass3_gemm():
-    A = torch.rand(2, 4092).cuda()
-    B = torch.rand(4*4092, 4092).cuda()
+#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
+@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
+def test_cutlass3_gemm(dtype):
+    for i in range(2):
+        A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+        B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+        #A = torch.rand(2, 4, dtype=dtype, device='cuda')
+        #B = torch.rand(4, 4, dtype=dtype, device='cuda')
+        #print('')
+        #print(A)
+        #print(B.t())

-    C1 = torch.matmul(A, B.t())
-    C2 = F.cutlass3_gemm(A, B.t())
+        C1 = torch.matmul(A, B.t())
+        C2 = F.cutlass3_gemm(A, B.t())
+        #print(C1)
+        #print(C2)
+        torch.testing.assert_close(C1, C2)

-#    torch.testing.assert_close(C1, C2)

 def test_pipeline_func():