OpenDAS / bitsandbytes / Commits / cad83994

Commit cad83994, authored Apr 28, 2023 by Tim Dettmers

    Added bit template.

parent f3e97ccb

Showing 6 changed files with 45 additions and 60 deletions (+45 −60)
csrc/kernels.cu           +32  −45
csrc/kernels.cuh           +1   −1
csrc/ops.cu                +7   −9
csrc/ops.cuh               +1   −1
csrc/pythonInterface.c     +2   −2
tests/test_functional.py   +2   −2
csrc/kernels.cu

@@ -2947,16 +2947,31 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 //// 9. write outputs to matmul output matrix
 //}

 #define ROWS 2
+template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit)
+{
+  if(limit_base + ITEMS <= limit)
+    reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(buffer)[idx/ITEMS];
+  else
+  {
+    for(int k = 0; k < ITEMS; k++)
+    {
+      if(limit_base + k < limit)
+        local[k] = buffer[idx + k];
+      else
+        local[k] = 0.0f;
+    }
+  }
+}
+
-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
   typedef cub::BlockReduce<T, THREADS> BlockReduce;
   __shared__ typename BlockReduce::TempStorage reduce;
   int col_offset = blockIdx.x * 8;

-  T local_A[8];
-  T local_B[8];
+  T local_A[128/BITS];
+  T local_B[128/BITS];
   T local_C[8];

   __shared__ T smem_C[8];
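The new vector_load helper issues a single 128-bit (16-byte) transaction when a full tile is in bounds and falls back to guarded scalar loads with zero-fill at the ragged tail. The BITS template parameter fixes the elements per load as 128/BITS: 8 half values or 4 float values per int4-sized load. A minimal standalone sketch of the same pattern (hypothetical demo kernel, not part of this commit):

    // Hypothetical demo: each thread stages ITEMS contiguous elements into
    // registers with the same vectorized-or-guarded logic as vector_load,
    // then writes them back out.
    template <typename T, typename TCAST, int ITEMS>
    __global__ void demo_vector_load(T *out, T *in, int n)
    {
      T local[ITEMS];
      int idx = (blockIdx.x*blockDim.x + threadIdx.x)*ITEMS;
      if(idx + ITEMS <= n)
        // fully in range: one 16-byte load covers all ITEMS elements
        reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(in)[idx/ITEMS];
      else
        for(int k = 0; k < ITEMS; k++)
          // tail: scalar loads, zero-fill past the end of the buffer
          local[k] = (idx + k < n) ? in[idx + k] : (T)0.0f;

      for(int k = 0; k < ITEMS; k++)
        if(idx + k < n) out[idx + k] = local[k];
    }

    // launch: with T = half and TCAST = int4, ITEMS = 128/16 = 8 halves per load
    // demo_vector_load<half, int4, 8><<<(n + 1023)/1024, 128>>>(d_out, d_in, n);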
@@ -2970,47 +2985,18 @@ template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M,
     local_C[k] = T(0);

-  for(int idx = threadIdx.x*8; idx < K; idx += blockDim.x*8)
+  for(int idx = threadIdx.x*128/BITS; idx < K; idx += blockDim.x*128/BITS)
   {
-    if(idx + 8 <= K)
-      reinterpret_cast<float4(&)[8]>(local_A)[0] = reinterpret_cast<float4*>(A)[idx/8];
-    else
-    {
-      for(int k = 0; k < 8; k++)
-      {
-        if(idx + k < K)
-          local_A[k] = A[idx + k];
-        else
-          local_A[k] = 0.0f;
-      }
-    }
+    vector_load<T, int4, 128/BITS>(local_A, A, idx, idx, K);

     for(int col = 0; col < 8; col++)
     {
       int offset_B = (col_offset + col)*ldb;
-      if(idx + 8 <= K)
-        reinterpret_cast<float4(&)[8]>(local_B)[0] = reinterpret_cast<float4*>(B)[(offset_B + idx)/8];
-      else
-      {
-        for(int k = 0; k < 8; k++)
-        {
-          if(idx + k < K)
-            local_B[k] = B[(offset_B + idx) + k];
-          else
-            local_B[k] = 0.0f;
-        }
-      }
+      vector_load<T, int4, 128/BITS>(local_B, B, offset_B + idx, idx, K);

-      #pragma unroll 8
-      for(int k = 0; k < 8; k++)
+      #pragma unroll 128/BITS
+      for(int k = 0; k < 128/BITS; k++)
       {
         local_C[col] += local_A[k]*local_B[k];
         //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
         //  printf("%i %i %f %f %f\n", k, threadIdx.x, (float)local_A[k], (float)local_B[k], (float)local_C[col]);
       }
     }
   }
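Each thread now walks K in strides of 128/BITS elements, accumulating partial dot products for the 8 output columns its block owns. A host-side reference of the value a block with a given col_offset computes (a sketch; it accumulates in float for readability, whereas the kernel accumulates in T):

    // out[col_offset + col] = dot(A[0..K), column (col_offset + col) of B),
    // where columns of B start at (col_offset + col)*ldb.
    template <typename T>
    void gemm_block_reference(int K, const T *A, const T *B, T *out,
                              int ldb, int col_offset)
    {
      for(int col = 0; col < 8; col++)
      {
        float acc = 0.0f;
        const T *Bcol = B + (size_t)(col_offset + col)*ldb;
        for(int k = 0; k < K; k++)
          acc += (float)A[k]*(float)Bcol[k];
        out[col_offset + col] = (T)acc;
      }
    }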
@@ -3022,9 +3008,11 @@ template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M,
   }

   if(threadIdx.x == 0)
+  {
     #pragma unroll 8
     for(int k = 0; k < 8; k++)
       smem_C[k] = local_C[k];
+  }
   else if(threadIdx.x >= 32)
     // early return for unused warps
     return;

@@ -3032,15 +3020,8 @@ template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M,
   __syncwarp();

-  //for(int k = 0; k < 8; k++)
-  //  if((float)local_C[k] != 0.0f)
-  //    printf("%i %f\n", threadIdx.x, (float)local_C[k]);

   if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
     out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
 }

 //#define ROWS 2
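After the accumulation (the reduction between these hunks is collapsed in this view), thread 0 stages its 8 values into smem_C; the newly added braces keep the unroll pragma and its loop unambiguously inside the if-branch. Warps past the first return early, and threads 0-7 write the block's results guarded by col_offset + threadIdx.x < M. Since each block thus produces at most 8 outputs, a matching launch presumably sizes the grid as follows (the num_blocks computation is elided from this diff, so this is an assumption):

    int num_blocks = (M + 7)/8;  // assumed: one block per 8 outputs, matching the M guard above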
@@ -3217,7 +3198,13 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 //      TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
 //      TC      * out, CStride dC, CBlockLayout, CThreadLayout tC,
 //      half alpha, half beta);

-// these are not used and make no sense, but the compiler needs them
 template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+// these are not used and make no sense, but the compiler needs them
+template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
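Because gemm_device is a template defined in this .cu translation unit and launched from ops.cu, every (T, BITS, THREADS) combination the host can request must be explicitly instantiated here; the lines above cover float and half at both 16 and 32 bits. The general pattern, as a minimal sketch (not code from this commit):

    // kernels.cu: a definition plus an explicit instantiation forces nvcc to
    // emit device code for that exact combination in this translation unit.
    template <typename T, int BITS> __global__ void k(T *p) { /* ... */ }
    template __global__ void k<float, 32>(float *p);

    // kernels.cuh exposes only the declaration, and ops.cu links against the
    // instantiation above:
    // template <typename T, int BITS> __global__ void k(T *p);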
csrc/kernels.cuh

@@ -138,6 +138,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 template <size_t stages_count /* Pipeline with stages_count stages */> __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);

-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);

 #endif
csrc/ops.cu

@@ -675,7 +675,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)

-template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
 {
   dim3 dimBlock(128);
@@ -689,20 +689,18 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
   cout << m << endl;
   cout << n << endl;
   cout << k << endl;

-  gemm_device<T, 16, 128>
-      <<<num_blocks, dimBlock, 0, 0>>>
-      (m, n, k,
-        A,
-        B,
-        out,
-        lda,
-        ldb,
-        ldc);
+  if(bits == 32)
+    gemm_device<T, 32, 128><<<num_blocks, dimBlock, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
+  else if(bits == 16)
+    gemm_device<T, 16, 128><<<num_blocks, dimBlock, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
 }

 //==============================================================
 //                   TEMPLATE DEFINITIONS
 //==============================================================

-template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc);
+template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
-template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc);
+template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);

 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
 template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
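On the host side, bits is a runtime argument, so gemm_host simply branches to whichever statically instantiated kernel matches; callers are expected to pass the bit width of T. A hypothetical call site (device pointers and leading dimensions assumed to be set up elsewhere):

    // half-precision GEMM through the new entry point; 16 matches sizeof(half)*8
    gemm_host<half>(m, n, k, d_A, d_B, d_out, lda, ldb, ldc, 16);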
csrc/ops.cuh

@@ -190,7 +190,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 void matmul4bite(half * A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);

-template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);

 void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
csrc/pythonInterface.c

@@ -21,9 +21,9 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate
 void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
-{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc); }
+{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }

 void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
-{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc); }
+{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }

 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
tests/test_functional.py

@@ -2352,8 +2352,8 @@ def test_normal_map_tree():
     print(pivots)

-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
+@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
+#@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
     for i in range(1):
         #A = torch.rand(2, 4092, dtype=dtype, device='cuda')