Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
c1bfb210
Commit
c1bfb210
authored
Apr 28, 2023
by
Tim Dettmers
Browse files
First baseline kernel.
parent
9cab14a3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
119 additions
and
33 deletions
+119
-33
bitsandbytes/functional.py
bitsandbytes/functional.py
+4
-4
csrc/kernels.cu
csrc/kernels.cu
+91
-12
csrc/kernels.cuh
csrc/kernels.cuh
+1
-1
csrc/ops.cu
csrc/ops.cu
+9
-4
csrc/ops.cuh
csrc/ops.cuh
+1
-1
csrc/pythonInterface.c
csrc/pythonInterface.c
+2
-2
tests/test_functional.py
tests/test_functional.py
+11
-9
No files found.
bitsandbytes/functional.py
View file @
c1bfb210
...
@@ -1429,7 +1429,7 @@ def cutlass3_gemm(
...
@@ -1429,7 +1429,7 @@ def cutlass3_gemm(
m
=
sB
[
1
]
m
=
sB
[
1
]
k
=
sB
[
0
]
k
=
sB
[
0
]
lda
=
B
.
stride
()[
(
1
if
transposed_B
else
0
)
]
lda
=
B
.
stride
()[
0
]
ldc
=
sB
[
1
]
ldc
=
sB
[
1
]
elif
len
(
sB
)
==
3
:
elif
len
(
sB
)
==
3
:
# special case
# special case
...
@@ -1446,7 +1446,7 @@ def cutlass3_gemm(
...
@@ -1446,7 +1446,7 @@ def cutlass3_gemm(
n
=
sA
[
2
]
n
=
sA
[
2
]
k
=
sB
[
0
]
*
sB
[
1
]
k
=
sB
[
0
]
*
sB
[
1
]
lda
=
m
lda
=
n
ldb
=
sA
[
2
]
ldb
=
sA
[
2
]
ldc
=
m
ldc
=
m
...
@@ -1454,7 +1454,7 @@ def cutlass3_gemm(
...
@@ -1454,7 +1454,7 @@ def cutlass3_gemm(
# B^T @ A^T = C^T
# B^T @ A^T = C^T
# [km, nk -> mn]
# [km, nk -> mn]
lda
=
ldb
=
ldc
=
1
#
lda = ldb = ldc = 1
#lda = 1
#lda = 1
#print(m, n, k, lda, ldb, ldc)
#print(m, n, k, lda, ldb, ldc)
is_on_gpu
([
B
,
A
,
out
])
is_on_gpu
([
B
,
A
,
out
])
...
@@ -1466,7 +1466,7 @@ def cutlass3_gemm(
...
@@ -1466,7 +1466,7 @@ def cutlass3_gemm(
ldc
=
ct
.
c_int32
(
ldc
)
ldc
=
ct
.
c_int32
(
ldc
)
alpha
=
ct
.
c_float
(
1.0
)
alpha
=
ct
.
c_float
(
1.0
)
beta
=
ct
.
c_float
(
0.0
)
beta
=
ct
.
c_float
(
0.0
)
lib
.
ccutlass_gemm
(
m
,
n
,
k
,
alpha
,
get_ptr
(
A
),
ld
b
,
get_ptr
(
B
),
ld
a
,
beta
,
get_ptr
(
out
),
ldc
)
lib
.
ccutlass_gemm
(
m
,
n
,
k
,
alpha
,
get_ptr
(
A
),
ld
a
,
get_ptr
(
B
),
ld
b
,
beta
,
get_ptr
(
out
),
ldc
)
return
out
return
out
...
...
csrc/kernels.cu
View file @
c1bfb210
...
@@ -2947,9 +2947,11 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
...
@@ -2947,9 +2947,11 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
//// 9. write outputs to matmul output matrix
//// 9. write outputs to matmul output matrix
//}
//}
#define ROWS 2
__global__
void
gemm_device
(
int
M
,
int
N
,
int
K
,
__global__
void
gemm_device
(
int
M
,
int
N
,
int
K
,
float
const
*
A
,
float
const
*
A
,
float
const
*
B
,
float
*
B
,
float
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
float
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
float
alpha
,
float
beta
)
float
alpha
,
float
beta
)
{
{
...
@@ -2958,29 +2960,106 @@ __global__ void gemm_device(int M, int N, int K,
...
@@ -2958,29 +2960,106 @@ __global__ void gemm_device(int M, int N, int K,
// 2. Dequantize B
// 2. Dequantize B
// 3. Fetch data from A and multiply
// 3. Fetch data from A and multiply
typedef
cub
::
BlockLoad
<
float
,
256
,
1
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
LoadA
;
typedef
cub
::
BlockLoad
<
float
,
256
,
4
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
LoadA
;
__shared__
typename
LoadA
::
TempStorage
loada
;
//__shared__ typename LoadA::TempStorage loada;
float
dataA
[
1
];
typedef
cub
::
BlockLoad
<
float
,
256
,
4
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
LoadB
;
int
valid_items
=
0
;
//__shared__ typename LoadB::TempStorage loadb;
typedef
cub
::
BlockReduce
<
float
,
256
>
BlockReduce
;
// Allocate shared memory for BlockReduce
//__shared__ typename BlockReduce::TempStorage reduce;
__shared__
union
{
typename
BlockReduce
::
TempStorage
reduce
;
typename
LoadB
::
TempStorage
loadb
;
typename
LoadA
::
TempStorage
loada
;
}
temp_storage
;
__shared__
float
[
16
*
256
]
tileA
;
float
dataA
[
4
];
float
local_B
[
4
];
float
local_accC
[
ROWS
];
int
valid_items
=
0
;
const
int
warp_id
=
threadIdx
.
x
/
32
;
const
int
warp_lane
=
threadIdx
.
x
%
32
;
const
int
col_offset
=
blockIdx
.
x
*
8
;
__shared__
float
tileA
[
ROWS
*
1024
];
__shared__
float
accumulatorC
[
ROWS
*
8
];
//#pragma unroll 8
//for(int i = 0; i < 8; i++)
// tileA[threadIdx.x + (i*256)] = 0.0f;
//__syncthreads();
if
(
threadIdx
.
x
<
64
)
accumulatorC
[
threadIdx
.
x
]
=
0.0
f
;
__syncthreads
();
for
(
int
idx
A
=
0
;
idx
A
<
M
*
K
;
idx
A
+=
256
)
for
(
int
inner_
idx
=
0
;
inner_
idx
<
K
;
inner_
idx
+=
1024
)
{
{
valid_items
=
M
*
K
-
idx
A
>
256
?
256
:
M
*
K
-
idx
A
;
valid_items
=
K
-
inner_
idx
>
1024
?
1024
:
K
-
inner_
idx
;
int
baserow
=
0
;
int
baserow
=
0
;
for
(
int
row
=
baserow
;
row
<
baserow
+
16
&&
row
<
M
+
;
row
++
)
for
(
int
row
=
baserow
;
row
<
(
baserow
+
ROWS
)
&&
row
<
N
;
row
++
)
{
{
LoadA
(
loada
).
Load
(
&
(
A
[(
row
*
lda
)
+
i
]),
dataA
,
valid_items
,
0.0
f
);
LoadA
(
temp_storage
.
loada
).
Load
(
&
(
A
[(
row
*
K
)
+
inner_idx
]),
dataA
,
valid_items
,
0.0
f
);
tileA
[
row
*
256
+
threadIdx
.
x
]
=
dataA
[
0
];
#pragma unroll 4
for
(
int
k
=
0
;
k
<
4
;
k
++
)
tileA
[
row
*
1024
+
threadIdx
.
x
+
(
k
*
blockDim
.
x
)]
=
dataA
[
k
];
__syncthreads
();
__syncthreads
();
}
}
baserow
+=
16
;
baserow
+=
ROWS
;
// load 16 columns from B at a time. B is transposed, so its like loading rows
// each warp loads one row
// each thread loads 128 byte
// col: inner_idx + warp_lane
// row: ldb*(offset + warp_id)
for
(
int
col
=
0
;
col
<
8
&&
(
col_offset
+
col
)
<
M
;
col
++
)
{
int
colB
=
col_offset
+
col
;
for
(
int
k
=
0
;
k
<
ROWS
;
k
++
)
local_accC
[
k
]
=
0.0
f
;
int
base_idxB
=
ldb
*
colB
;
valid_items
=
K
-
inner_idx
>
1024
?
1024
:
K
-
inner_idx
;
LoadB
(
temp_storage
.
loadb
).
Load
(
&
(
B
[
base_idxB
+
inner_idx
]),
local_B
,
valid_items
,
0.0
f
);
__syncthreads
();
for
(
int
row
=
0
;
row
<
ROWS
&&
row
<
N
;
row
++
)
{
#pragma unroll 4
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
int
idxA
=
row
*
1024
+
threadIdx
.
x
+
(
blockDim
.
x
*
k
);
local_accC
[
row
]
+=
tileA
[
idxA
]
*
local_B
[
k
];
}
local_accC
[
row
]
=
BlockReduce
(
temp_storage
.
reduce
).
Reduce
(
local_accC
[
row
],
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
atomicAdd
(
&
accumulatorC
[
row
*
8
+
col
],
local_accC
[
row
]);
}
}
}
}
for
(
int
row
=
0
;
row
<
ROWS
&&
row
<
N
;
row
++
)
{
int
out_idx
=
ldc
*
row
+
col_offset
;
//if(threadIdx.x < 8)
// if(accumulatorC[row*8 + threadIdx.x] != 0.0)
// printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
if
(
threadIdx
.
x
<
8
&&
(
col_offset
+
threadIdx
.
x
)
<
M
)
{
//printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
out
[
out_idx
+
threadIdx
.
x
]
=
accumulatorC
[
row
*
8
+
threadIdx
.
x
];
}
}
}
}
...
...
csrc/kernels.cuh
View file @
c1bfb210
...
@@ -140,7 +140,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
...
@@ -140,7 +140,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
__global__
void
gemm_device
(
int
M
,
int
N
,
int
K
,
__global__
void
gemm_device
(
int
M
,
int
N
,
int
K
,
float
const
*
A
,
float
const
*
A
,
float
const
*
B
,
float
*
B
,
float
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
float
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
float
alpha
,
float
beta
);
float
alpha
,
float
beta
);
...
...
csrc/ops.cu
View file @
c1bfb210
...
@@ -669,8 +669,6 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
...
@@ -669,8 +669,6 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
int
threads
=
256
;
int
threads
=
256
;
int
num_blocks
=
(
n
+
(
256
*
batch_size
)
+
1
)
/
(
batch_size
*
256
);
int
num_blocks
=
(
n
+
(
256
*
batch_size
)
+
1
)
/
(
batch_size
*
256
);
printf
(
"%i %i
\n
"
,
num_blocks
,
batch_size
);
with_staging_unified
<
2
><<<
num_blocks
,
threads
>>>
(
A
,
B
,
n
,
batch_size
);
with_staging_unified
<
2
><<<
num_blocks
,
threads
>>>
(
A
,
B
,
n
,
batch_size
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
}
...
@@ -680,15 +678,22 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
...
@@ -680,15 +678,22 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
void
gemm_host
(
int
m
,
int
n
,
int
k
,
void
gemm_host
(
int
m
,
int
n
,
int
k
,
float
alpha
,
float
alpha
,
float
const
*
A
,
int
lda
,
float
const
*
A
,
int
lda
,
float
const
*
B
,
int
ldb
,
float
*
B
,
int
ldb
,
float
beta
,
float
beta
,
float
*
C
,
int
ldc
)
float
*
C
,
int
ldc
)
{
{
dim3
dimBlock
(
256
);
dim3
dimBlock
(
256
);
int
num_blocks
=
(
n
+
31
)
/
32
;
int
num_blocks
=
(
m
+
7
)
/
8
;
cout
<<
num_blocks
<<
endl
;
cout
<<
num_blocks
<<
endl
;
cout
<<
lda
<<
endl
;
cout
<<
ldb
<<
endl
;
cout
<<
ldc
<<
endl
;
cout
<<
m
<<
endl
;
cout
<<
n
<<
endl
;
cout
<<
k
<<
endl
;
gemm_device
gemm_device
<<<
num_blocks
,
dimBlock
,
0
,
0
>>>
<<<
num_blocks
,
dimBlock
,
0
,
0
>>>
(
m
,
n
,
k
,
(
m
,
n
,
k
,
...
...
csrc/ops.cuh
View file @
c1bfb210
...
@@ -193,7 +193,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
...
@@ -193,7 +193,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
void
gemm_host
(
int
m
,
int
n
,
int
k
,
void
gemm_host
(
int
m
,
int
n
,
int
k
,
float
alpha
,
float
alpha
,
float
const
*
A
,
int
ldA
,
float
const
*
A
,
int
ldA
,
float
const
*
B
,
int
ldB
,
float
*
B
,
int
ldB
,
float
beta
,
float
beta
,
float
*
C
,
int
ldC
);
float
*
C
,
int
ldC
);
...
...
csrc/pythonInterface.c
View file @
c1bfb210
...
@@ -24,7 +24,7 @@ void
...
@@ -24,7 +24,7 @@ void
cppgemm
(
int
m
,
int
n
,
int
k
,
cppgemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
float
alpha
,
float
const
*
A
,
int
ldA
,
float
const
*
A
,
int
ldA
,
float
const
*
B
,
int
ldB
,
float
*
B
,
int
ldB
,
float
beta
,
float
beta
,
float
*
C
,
int
ldC
)
float
*
C
,
int
ldC
)
{
gemm_host
(
m
,
n
,
k
,
alpha
,
A
,
ldA
,
B
,
ldB
,
beta
,
C
,
ldC
);}
{
gemm_host
(
m
,
n
,
k
,
alpha
,
A
,
ldA
,
B
,
ldB
,
beta
,
C
,
ldC
);}
...
@@ -320,7 +320,7 @@ extern "C"
...
@@ -320,7 +320,7 @@ extern "C"
void
ccutlass_gemm
(
int
m
,
int
n
,
int
k
,
void
ccutlass_gemm
(
int
m
,
int
n
,
int
k
,
float
alpha
,
float
alpha
,
float
const
*
A
,
int
ldA
,
float
const
*
A
,
int
ldA
,
float
const
*
B
,
int
ldB
,
float
*
B
,
int
ldB
,
float
beta
,
float
beta
,
float
*
C
,
int
ldC
)
float
*
C
,
int
ldC
)
{
cppgemm
(
m
,
n
,
k
,
alpha
,
A
,
ldA
,
B
,
ldB
,
beta
,
C
,
ldC
);}
{
cppgemm
(
m
,
n
,
k
,
alpha
,
A
,
ldA
,
B
,
ldB
,
beta
,
C
,
ldC
);}
...
...
tests/test_functional.py
View file @
c1bfb210
...
@@ -2353,17 +2353,19 @@ def test_normal_map_tree():
...
@@ -2353,17 +2353,19 @@ def test_normal_map_tree():
def
test_cutlass3_gemm
():
def
test_cutlass3_gemm
():
#A = torch.rand(2, 2).cuda()
A
=
torch
.
rand
(
2
,
4092
).
cuda
()
#B = torch.rand(2, 2).cuda()
B
=
torch
.
rand
(
4
*
4092
,
4092
).
cuda
()
A
=
torch
.
arange
(
4
).
reshape
(
2
,
2
).
float
().
cuda
().
contiguous
()
B
=
torch
.
ones
(
2
,
2
).
float
().
cuda
()
print
(
''
)
#
print('')
print
(
A
)
#
print(A)
print
(
B
)
#
print(B
.t()
)
C1
=
torch
.
matmul
(
A
,
B
)
C1
=
torch
.
matmul
(
A
,
B
.
t
())
C2
=
F
.
cutlass3_gemm
(
A
,
B
)
C2
=
F
.
cutlass3_gemm
(
A
,
B
.
t
())
#print(C1)
#print(C2)
torch
.
testing
.
assert_close
(
C1
,
C2
)
def
test_pipeline_func
():
def
test_pipeline_func
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment