OpenDAS / bitsandbytes · Commit f89ff93e

Authored Jul 03, 2023 by Tim Dettmers

Initial 4-bit naive batch size 1, 81 vs 185.

parent e54d2730
Showing 7 changed files with 240 additions and 65 deletions (+240 −65)
bitsandbytes/functional.py    +1   −1
csrc/kernels.cu             +145  −17
csrc/kernels.cuh              +2   −0
csrc/ops.cu                  +23   −1
csrc/ops.cuh                  +1   −0
csrc/pythonInterface.c        +6   −0
tests/test_functional.py     +62  −46
bitsandbytes/functional.py

@@ -1503,7 +1503,7 @@ def cutlass3_gemm(
     ldc = ct.c_int32(ldc)

     if B.dtype == torch.uint8:
-        lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+        lib.cgemm_4bit_inference_naive(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
     elif A.dtype == torch.float32:
         lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
     elif A.dtype == torch.float16:
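The one-line change above switches the uint8 (packed 4-bit) branch of cutlass3_gemm from lib.cgemm_4bit_inference to the new lib.cgemm_4bit_inference_naive entry point. A minimal sketch of how this path is exercised from Python, mirroring the call pattern used in tests/test_functional.py further down in this commit (shapes and function names are taken from the diff; a CUDA device and a built bitsandbytes library are assumed):

import math
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Batch-size-1 inference shape: A is [1, K]; B is [N, K] with N = 4*K.
dim = 4096
A = torch.randn(1, dim, dtype=torch.float16, device='cuda')
B = torch.randn(4 * dim, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)

# NF4-quantize B; `state` carries the absmax values and blocksize passed to the kernel.
qB, state = F.quantize_nf4(B)

# qB.dtype == torch.uint8 here, so this dispatches to lib.cgemm_4bit_inference_naive.
C2 = F.cutlass3_gemm(A, qB.t(), state=state)

# Reference path the test compares against.
C1 = bnb.matmul_4bit(A, qB.t(), state)
print((C1 - C2).abs().mean().item())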
csrc/kernels.cu

@@ -3088,7 +3088,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
   }
 }

-#define WARPS 5
+#define WARPS 3
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T * B, T * out, int lda, int ldb, int ldc)
 {

@@ -3298,15 +3298,15 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
   }
 }

-template <typename T> __device__ void printnonzero(T *A, int num_values)
+template <typename T> __device__ void printnonzero(T *A, int num_values, const char *strval)
 {
   for(int i = 0; i < num_values; i++)
     if((float)A[i] != 0.0)
-      printf("%i %f\n", i, (float)A[i]);
+      printf("%s %i %f\n", strval, i, (float)A[i]);
 }

-template __device__ void printnonzero<float>(float *A, int num_values);
-template __device__ void printnonzero<half>(half *A, int num_values);
+template __device__ void printnonzero<float>(float *A, int num_values, const char *strval);
+template __device__ void printnonzero<half>(half *A, int num_values, const char *strval);

 __device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)

@@ -3315,6 +3315,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   using namespace nvcuda;
   int col_offset = blockIdx.x * 32;
   const int warp_id = threadIdx.x / 32;
+  const int warp_idx = threadIdx.x % 32;
   const int half_warp_id = threadIdx.x / 16;
   const int half_warp_lane = threadIdx.x % 16;
   const int batch_size_warps = (WARPS-1)*2;

@@ -3324,23 +3325,30 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   #pragma unroll 16
   for(int i = 0; i < 16; i++)
     quant_map[i] = nf4_data[i];
+  //__shared__ T quant_map[16*160];

   T local_A[2];
   T local_B[64];
   unsigned char local_B_4bit[32];

   const int a_tile_offset = 16;
   const int b_tile_offset = (16*32 + 16);

-  __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
+  __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))];
   __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
-  //__shared__ T smem_C[8*32];
+  __shared__ T smem_C[8*32];

   wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
   wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
   wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
   wmma::fill_fragment(c_frag, 0.0f);

+  for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
+    smem_C[i] = 0.0f;
+  __syncthreads();
+
   int ticktock = 0;
   int idx = 0 + threadIdx.x;
   int loaded_values = 0;

@@ -3366,8 +3374,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
       #pragma unroll 64
       for(int col = 0; col < 64; col+=2)
       {
-        local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
-        local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
+        //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
+        //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
+        //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
+        //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
+        //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
+        //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
+        //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
+        //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
+        local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0);
+        local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0);
       }
     }

@@ -3391,13 +3408,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
         smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
     }
     ticktock = ticktock == 0 ? 1 : 0;
+    //if(threadIdx.x == 0)
+    //printf("aa %i %i\n", idx, loaded_values);

   //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   {
     idx = base_idx + threadIdx.x;
+    //if(threadIdx.x == 0)
+    //printf("%i %i\n", idx, loaded_values);

-    __syncthreads();
+    //__syncthreads();
     if(idx < K && warp_id < (WARPS-1))
     {
       if(loaded_values == 0)

@@ -3425,11 +3446,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
         #pragma unroll 64
         for(int col = 0; col < 64; col+=2)
         {
-          local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
-          local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+          //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
+          //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+          //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
+          //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
+          //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
+          //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);
+          //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
+          //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
+          local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
+          local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
         }
+        //printnonzero<T>(local_B, 128, "");
       }

     smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

@@ -3463,6 +3490,11 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   }

   __syncthreads();
+  //if(threadIdx.x == 0)
+  //{
+  //  printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
+  //  printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
+  //}
   if(warp_id != (WARPS-1)){ return; }
   // only warp_id == (WARPS-1) from here
   int warp_lane = threadIdx.x % 32;

@@ -3470,6 +3502,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   ticktock = ticktock == 0 ? 1 : 0;
   for(int k = 0; k < batch_size_warps; k++)
   {
+    //if(warp_lane == 0)
+    //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
     wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
     wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
     wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

@@ -3477,14 +3511,101 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   // 129 mu
   if(warp_id == (WARPS-1))
-    wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
+    wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);

-  printnonzero<T>(smem_A, 32);
+  //printnonzero<T>(smem_C, 32, "");

   if(col_offset + warp_lane < M)
-    out[col_offset + warp_lane] = smem_A[warp_lane];
+    out[col_offset + warp_lane] = smem_C[warp_lane];
 }

+#define num_values_4bit 16
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+  // per threadblock:
+  // load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
+  // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
+  // 4 warps -> 4 loads per iter
+  // 1x128 * 128x4 -> 1x4 outputs
+  typedef cub::WarpReduce<T> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_storage[4];
+
+  const int warp_idx = threadIdx.x / 32;
+  const int warp_lane = threadIdx.x % 32;
+  const int row_B = 4*blockIdx.x + warp_idx;
+  T local_C = T(0);
+
+  T quant_map[16];
+  #pragma unroll 16
+  for(int i = 0; i < 16; i++)
+    quant_map[i] = nf4_data[i];
+
+  unsigned char local_B_4bit[num_values_4bit/2];
+  T local_B[num_values_4bit];
+
+  // need to increase occupancy by splitting the rows, but can be done later
+  // A: [1, K]
+  // B: [N, K]
+  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
+  {
+    int offset_B = ldb*row_B + (inner_idx/2);
+    int absidx = (2*offset_B)/blocksize;
+    T local_absmax = __ldg(&(absmax[absidx]));
+    //printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B);
+
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit/2; k++)
+    {
+      if((inner_idx/2) < K && row_B < M)
+        local_B_4bit[k] = B[offset_B + k];
+      else
+        local_B_4bit[k] = 0b01110111;
+    }
+    //if(row_B < M)
+    //{
+    //  if((inner_idx/num_values_4bit) < K)
+    //    reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[offset_B/(num_values_4bit/2)];
+    //  else
+    //  {
+    //    for(int k = 0; k < num_values_4bit/2; k++)
+    //    {
+    //      if((inner_idx/2) < K && row_B < M)
+    //        local_B_4bit[k] = B[offset_B + k];
+    //      else
+    //        local_B_4bit[k] = 0b01110111;
+    //    }
+    //  }
+    //}
+
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit; k++)
+    {
+      local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
+      local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+    }
+    //printnonzero<T>(local_B, 4, "B values: ");
+
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit; k++)
+      local_C += A[inner_idx + k]*local_B[k];
+  }
+
+  local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
+
+  if(row_B < M && warp_lane == 0)
+    out[row_B] = local_C;
+}

 //#define ROWS 2
 //template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
 //{

@@ -3647,8 +3768,15 @@ template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * _
 template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half * B, half * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half * B, half * out, int lda, int ldb, int ldc);

+template __global__ void kgemm_4bit_inference<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+
+template __global__ void kgemm_4bit_inference_naive<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

 template __global__ void kExtractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
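For reference, the arithmetic performed per output row by kgemm_4bit_inference_naive is a plain dot product against NF4-dequantized weights: each byte of B packs two 4-bit codes (high nibble first), each code indexes the nf4_data table, and the looked-up value is scaled by the per-block absmax. Below is a minimal NumPy sketch of that dequantize-and-dot step, showing the intended semantics only; the warp tiling, loop bounds and cub::WarpReduce reduction of the actual kernel are not modeled, and the helper names are illustrative:

import numpy as np

# NF4 codebook copied from nf4_data in csrc/kernels.cu.
NF4 = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0], dtype=np.float32)

def dequant_nf4_row(packed, absmax, blocksize):
    """Unpack one packed row of B: two 4-bit codes per byte (high nibble first),
    each scaled by the absmax of its quantization block."""
    hi = packed >> 4
    lo = packed & 0x0F
    codes = np.empty(packed.size * 2, dtype=np.int64)
    codes[0::2], codes[1::2] = hi, lo
    scales = absmax[np.arange(codes.size) // blocksize]
    return NF4[codes] * scales

def gemv_4bit_naive(a, b_packed, absmax, blocksize):
    """out[n] = dot(a, dequant(B[n])) for a of shape [K] and packed B of shape [N, K//2]."""
    return np.array([dequant_nf4_row(row, absmax[n], blocksize) @ a
                     for n, row in enumerate(b_packed)])

# Tiny usage example with K=8, N=2, blocksize=4 (one absmax per block of 4 values).
rng = np.random.default_rng(0)
a = rng.standard_normal(8).astype(np.float32)
b_packed = rng.integers(0, 256, size=(2, 4), dtype=np.uint8)
absmax = np.abs(rng.standard_normal((2, 2))).astype(np.float32)
print(gemv_4bit_naive(a, b_packed, absmax, blocksize=4))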
csrc/kernels.cuh

@@ -106,6 +106,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
 template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);

 __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);

@@ -124,6 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T * B, T * out, int lda, int ldb, int ldc);
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
csrc/ops.cu

@@ -723,7 +723,28 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
   //cout << m << endl;
   //cout << n << endl;
   //cout << k << endl;
-  kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference<T, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
   //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
+
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+  int num_blocks = (m+3)/4;
+  cout << num_blocks << endl;
+  //cout << lda << endl;
+  //cout << ldb << endl;
+  //cout << ldc << endl;
+  //cout << m << endl;
+  //cout << n << endl;
+  //cout << k << endl;
+  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+}

@@ -747,6 +768,7 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
 template void func<float, _MUL>(float *A, float *B, float value, long n);

 template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 //template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
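The new host launcher gemm_4bit_inference_naive uses a deliberately simple geometry: (m+3)/4 blocks of 128 threads, i.e. four warps per block with one output element per warp (matching row_B = 4*blockIdx.x + warp_idx in the kernel). A small sketch of that grid computation, using a hypothetical helper name that is not part of the library:

def naive_4bit_launch_config(m, threads_per_block=128):
    """Hypothetical helper mirroring the launch math above: each 32-thread warp
    reduces one output element, so a 128-thread block covers 4 outputs and the
    grid needs ceil(m/4) == (m+3)//4 blocks."""
    outputs_per_block = threads_per_block // 32
    num_blocks = (m + outputs_per_block - 1) // outputs_per_block
    return num_blocks, threads_per_block

print(naive_4bit_launch_config(16384))  # -> (4096, 128)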
csrc/ops.cuh

@@ -200,6 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
csrc/pythonInterface.c

@@ -28,6 +28,9 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
 void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
 { gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

+void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+
 #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
 void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \

@@ -345,6 +348,9 @@ extern "C"
   void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
   { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

+  void cgemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
+  { gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+
   void *cget_managed_ptr(size_t bytes)
   {
     void *ptr;
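The extern "C" wrapper above is what makes the kernel reachable from Python: lib.cgemm_4bit_inference_naive in the functional.py diff is exactly this exported symbol, called through the ctypes handle that bitsandbytes loads at import time. A hedged sketch of a direct ctypes binding follows; the argument types are inferred from the C signature in this commit, while the library filename and the explicit argtypes setup are illustrative only (bitsandbytes itself passes device pointers through its own get_ptr helper rather than declaring argtypes):

import ctypes as ct

# Illustrative only: load the compiled bitsandbytes CUDA library by an assumed path.
lib = ct.cdll.LoadLibrary("libbitsandbytes_cuda121.so")

# void cgemm_4bit_inference_naive(int m, int n, int k, half *A, unsigned char *B,
#                                 float *absmax, half *out, int lda, int ldb, int ldc, int blocksize)
lib.cgemm_4bit_inference_naive.restype = None
lib.cgemm_4bit_inference_naive.argtypes = [
    ct.c_int32, ct.c_int32, ct.c_int32,   # m, n, k
    ct.c_void_p, ct.c_void_p,             # A (half*), B (unsigned char*)
    ct.c_void_p, ct.c_void_p,             # absmax (float*), out (half*)
    ct.c_int32, ct.c_int32, ct.c_int32,   # lda, ldb, ldc
    ct.c_int32,                           # blocksize
]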
tests/test_functional.py

@@ -1773,17 +1773,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)

-batch_size = 1
-seqdim = 1
+batch_size = 32
+seqdim = 512+256
 values = []
 #values.append((batch_size, seqdim, 768, 4 * 768))
 #values.append((batch_size, seqdim, 1024, 4*1024))
 #values.append((batch_size, seqdim, 1536, 4*1536))
 #values.append((batch_size, seqdim, 2048, 4*2048))
 #values.append((batch_size, seqdim, 2560, 4*2560))
-values.append((batch_size, seqdim, 4096, 4*4096))
-values.append((batch_size, seqdim, 5120, 4*5120))
-values.append((batch_size, seqdim, 6656, 4*6656))
+#values.append((batch_size, seqdim, 4096, 4*4096))
+#values.append((batch_size, seqdim, 5120, 4*5120))
+#values.append((batch_size, seqdim, 6656, 4*6656))
 values.append((batch_size, seqdim, 8192, 4*8192))
 #values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))

@@ -1827,19 +1827,19 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     print(f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
-    torch.cuda.synchronize()
-    print(f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
+    #torch.cuda.synchronize()
+    #print(f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
-    torch.cuda.synchronize()
-    print(f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
+    #torch.cuda.synchronize()
+    #print(f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     torch.cuda.synchronize()
     t0 = time.time()

@@ -1901,21 +1901,21 @@ def test_bench_matmul(batch, seq, model, hidden):
     #torch.cuda.synchronize()
     #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linear8bit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit(A)
-    torch.cuda.synchronize()
-    print(f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit(A)
+    #torch.cuda.synchronize()
+    #print(f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linearMixedBit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linearMixedBit(A)
-    torch.cuda.synchronize()
-    print(f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #print(f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     #linear8bit_train(A)
     #torch.cuda.synchronize()

@@ -2411,10 +2411,14 @@ def test_cutlass3_gemm(dtype):
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_gemm_4bit(dtype):
+    print('')
     #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4096, 5120, 6656, 8192]:
     #for dim in [32]:
-    for dim in [32]:
+    for dim in [4096]:
+    #for dim in [5120]:
+    #for dim in [6656]:
+    #for dim in [128]:
         errs = []
         relerrs = []
         max_err = 0

@@ -2424,24 +2428,36 @@ def test_gemm_4bit(dtype):
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
             #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
-            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
+            A = torch.randn(1, dim+2, dtype=dtype, device='cuda')
+            B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
+            #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)

             #print('')
             #print(A)
             #print(B.t())
             #A[:, :-1] = 0
             #B[:, :-1] = 0
+            #A.flatten()[:-1] = 0
+            #B.flatten()[:-1] = 0

             qB, state = F.quantize_nf4(B)
             F.dequantize_nf4(qB, state)

-            C3 = torch.matmul(A, B.t())
+            #C3 = torch.matmul(A, B.t())
             C2 = F.cutlass3_gemm(A, qB.t(), state=state)
             C1 = bnb.matmul_4bit(A, qB.t(), state)
+            print(C1)
+            #print(state)
+            print(C2)
+            #print(qB)

+            #print('')
+            #print(A)
+            #print(B)
+            #print('='*89)
+            #print(C1)
+            #print(C2)
+            #print(C3)

             #print(C1.shape, C2.shape)

@@ -2455,7 +2471,7 @@ def test_gemm_4bit(dtype):
             max_relerr = max(relerr.max(), max_relerr)
             err = err.mean().item()
             relerr = relerr.mean().item()
-            print(err)
+            #print(err)

             errs.append(err)
             relerrs.append(relerr)

@@ -2463,20 +2479,20 @@ def test_gemm_4bit(dtype):
             if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
                 print('')
                 print(i, err, relerr)
-                print(A.flatten()[-6:])
-                print(B.flatten()[-6:])
-                out = A.flatten()[-6:]*B.flatten()[-6:]
-                print(out)
-                print(out[:-1].sum())
+                #print(A.flatten()[-6:])
+                #print(B.flatten()[-6:])
+                #out = A.flatten()[-6:]*B.flatten()[-6:]
+                #print(out)
+                #print(out[:-1].sum())
                 print('='*80)
-                print(C1.flatten()[-6:])
-                print(C2.flatten()[-6:])
+                #print(C1.flatten()[-6:])
+                #print(C2.flatten()[-6:])
                 #assert False, 'ERROR'

         c = int(C1.numel()*0.0014*(dim/256))+1

         c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
-        #print(c/math.sqrt(dim))
+        print(c/math.sqrt(dim))
         print('')
         print(dim, sum(errs)/len(errs)/math.sqrt(dim))
         print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
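For context, the pass/fail logic in test_gemm_4bit compares the two 4-bit matmul paths element-wise and thresholds a mean absolute error and a mean relative error. A hedged sketch of those metrics; the names err/relerr follow the test, while the exact per-element computation before .mean() is an assumption based on the surrounding lines of this hunk:

import torch

def gemm_errors(C1, C2):
    """Mean absolute error and mean relative error between two result tensors,
    in the spirit of the err/relerr values printed and thresholded in
    test_gemm_4bit (assumed definitions; the test computes them a few lines
    above the hunk shown here)."""
    err = (C1 - C2).abs().float()
    relerr = err / (C1.abs().float() + 1e-8)
    return err.mean().item(), relerr.mean().item()

# The test then flags a sample when err / C1.abs().mean() > 5e-5 or err > 3.2e-5.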