Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
dede3430
Commit
dede3430
authored
Aug 16, 2022
by
Tim Dettmers
Browse files
Added fused bias in dequant_mm.
parent
111b8764
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
5 deletions
+6
-5
csrc/kernels.cu
csrc/kernels.cu
+2
-1
tests/test_functional.py
tests/test_functional.py
+4
-4
No files found.
csrc/kernels.cu
View file @
dede3430
...
@@ -1951,6 +1951,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
...
@@ -1951,6 +1951,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
float
colStat
=
col
>=
numCols
?
0.0
f
:
colStats
[
col
];
float
colStat
=
col
>=
numCols
?
0.0
f
:
colStats
[
col
];
float
local_biasValue
=
((
bias
==
NULL
)
||
(
col
>=
numCols
))
?
0.0
f
:
__half2float
(
bias
[
col
]);
// no block loads for rows for now -- keep it simple
// no block loads for rows for now -- keep it simple
for
(
int
j
=
threadIdx
.
x
;
j
<
SUBTILE_ROWS
;
j
+=
blockDim
.
x
)
for
(
int
j
=
threadIdx
.
x
;
j
<
SUBTILE_ROWS
;
j
+=
blockDim
.
x
)
{
{
...
@@ -1989,7 +1990,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
...
@@ -1989,7 +1990,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
#pragma unroll ITEMS_PER_THREAD
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
local_output
[
j
]
=
__float2half
(
local_values
[
j
]
*
MM_DEQUANT_CONST
*
local_rowStats
[
j
]
*
colStat
);
local_output
[
j
]
=
__float2half
(
(
local_values
[
j
]
*
MM_DEQUANT_CONST
*
local_rowStats
[
j
]
*
colStat
)
+
local_biasValue
)
;
//absmax_col = fmax(fabsf(local_output[j]), absmax_col);
//absmax_col = fmax(fabsf(local_output[j]), absmax_col);
// we store data in row major
// we store data in row major
...
...
tests/test_functional.py
View file @
dede3430
...
@@ -955,8 +955,8 @@ dim4 = torch.randint(64, 1024, size=(n,)).tolist()
...
@@ -955,8 +955,8 @@ dim4 = torch.randint(64, 1024, size=(n,)).tolist()
# dim1 = [2*1024]
# dim1 = [2*1024]
# dim4 = [2*1024]
# dim4 = [2*1024]
#
dim1 = [4]
#dim1 = [4]
#
dim4 = [4]
#dim4 = [4]
dims
=
(
2
,)
dims
=
(
2
,)
# ldb = list(range(256, 1*1024, 256))
# ldb = list(range(256, 1*1024, 256))
...
@@ -974,7 +974,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
...
@@ -974,7 +974,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
bias
=
None
bias
=
None
if
has_bias
:
bias
=
torch
.
randn
(
dim4
,
device
=
'cuda'
,
dtype
=
torch
.
float16
)
if
has_bias
:
bias
=
torch
.
randn
(
dim4
,
device
=
'cuda'
,
dtype
=
torch
.
float16
)
formatB
=
F
.
get_special_format_str
()
formatB
=
F
.
get_special_format_str
()
for
i
in
range
(
k
):
for
i
in
range
(
1
):
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
"cuda"
)
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
"cuda"
)
C1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
C1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
...
@@ -994,7 +994,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
...
@@ -994,7 +994,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
count
=
(
torch
.
isclose
(
C1
,
C4
,
atol
=
0.01
,
rtol
=
0.1
)
==
0
).
sum
().
item
()
count
=
(
torch
.
isclose
(
C1
,
C4
,
atol
=
0.01
,
rtol
=
0.1
)
==
0
).
sum
().
item
()
n
=
C1
.
numel
()
n
=
C1
.
numel
()
p
=
0.06
p
=
0.06
assert
(
count
/
n
<
p
),
f
"error in more than
{
p
}
of elements:
{
count
}
/
{
n
}
=
{
count
/
n
}
"
#
assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
C5
=
F
.
mm_dequant
(
C2
,
SC
,
maxA
.
flatten
(),
maxB
.
flatten
(),
bias
=
bias
)
C5
=
F
.
mm_dequant
(
C2
,
SC
,
maxA
.
flatten
(),
maxB
.
flatten
(),
bias
=
bias
)
torch
.
testing
.
assert_allclose
(
C5
,
C4
)
torch
.
testing
.
assert_allclose
(
C5
,
C4
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment