Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
c35ed09b
Commit
c35ed09b
authored
Apr 30, 2023
by
Tim Dettmers
Browse files
Double frag 440.
parent
604bb3fb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
12 deletions
+17
-12
csrc/kernels.cu
csrc/kernels.cu
+16
-11
tests/test_functional.py
tests/test_functional.py
+1
-1
No files found.
csrc/kernels.cu
View file @
c35ed09b
...
...
@@ -3053,19 +3053,24 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//// Allocate shared memory for BlockReduce
//__shared__ typename BlockReduce::TempStorage reduce;
int
col_offset
=
blockIdx
.
x
*
8
;
const
int
warp_id
=
threadIdx
.
x
/
32
;
const
int
warp_lane
=
threadIdx
.
x
%
32
;
const
int
half_
warp_id
=
threadIdx
.
x
/
16
;
const
int
half_
warp_lane
=
threadIdx
.
x
%
16
;
T
local_A
[
64
/
BITS
];
T
local_B
[
64
/
BITS
];
T
local_C
[
8
];
__shared__
T
smem_A
[
WARPS
*
32
*
16
];
__shared__
T
smem_B
[
WARPS
*
16
*
8
];
const
int
a_tile_offset
=
32
*
16
;
const
int
b_tile_offset
=
16
*
8
;
__shared__
T
smem_A
[
WARPS
*
32
*
16
*
2
];
__shared__
T
smem_B
[
WARPS
*
16
*
8
*
2
];
__shared__
T
smem_C
[
WARPS
*
32
*
8
];
wmma
::
fragment
<
wmma
::
matrix_a
,
32
,
8
,
16
,
half
,
wmma
::
row_major
>
a_frag
;
wmma
::
fragment
<
wmma
::
matrix_b
,
32
,
8
,
16
,
half
,
wmma
::
col_major
>
b_frag
;
wmma
::
fragment
<
wmma
::
matrix_a
,
32
,
8
,
16
,
half
,
wmma
::
row_major
>
a2_frag
;
wmma
::
fragment
<
wmma
::
matrix_b
,
32
,
8
,
16
,
half
,
wmma
::
col_major
>
b2_frag
;
wmma
::
fragment
<
wmma
::
accumulator
,
32
,
8
,
16
,
half
>
c_frag
;
wmma
::
fill_fragment
(
c_frag
,
0.0
f
);
...
...
@@ -3087,32 +3092,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//int block_idx = 0;
//for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
for
(
int
base_idx
=
0
;
base_idx
<
K
;
base_idx
+=
16
)
for
(
int
base_idx
=
0
;
base_idx
<
K
;
base_idx
+=
32
)
{
int
idx
=
base_idx
+
threadIdx
.
x
;
if
(
threadIdx
.
x
<
16
)
{
if
(
idx
>=
K
)
{
smem_A
[
threadIdx
.
x
]
=
0.0
f
;
smem_B
[
threadIdx
.
x
]
=
0.0
f
;
//
smem_B[threadIdx.x] = 0.0f;
}
else
{
smem_A
[
threadIdx
.
x
]
=
A
[
idx
];
smem_A
[
half_warp_lane
+
(
half_warp_id
*
a_tile_offset
)
]
=
A
[
idx
];
for
(
int
col
=
0
;
col
<
8
;
col
++
)
smem_B
[
threadIdx
.
x
+
(
col
*
16
)]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
smem_B
[
half_warp_lane
+
(
half_warp_id
*
b_tile_offset
)
+
(
col
*
16
)]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
}
}
__syncthreads
();
wmma
::
load_matrix_sync
(
a_frag
,
&
(
smem_A
[
0
]),
16
);
// 111 mu
wmma
::
load_matrix_sync
(
b_frag
,
&
(
smem_B
[
0
]),
16
);
// 35 mu
wmma
::
load_matrix_sync
(
a2_frag
,
&
(
smem_A
[
32
*
16
]),
16
);
// 111 mu
wmma
::
load_matrix_sync
(
b2_frag
,
&
(
smem_B
[
16
*
8
]),
16
);
// 35 mu
wmma
::
mma_sync
(
c_frag
,
a_frag
,
b_frag
,
c_frag
);
wmma
::
mma_sync
(
c_frag
,
a2_frag
,
b2_frag
,
c_frag
);
}
// 129 mu
...
...
tests/test_functional.py
View file @
c35ed09b
...
...
@@ -2373,7 +2373,7 @@ def test_cutlass3_gemm(dtype):
#print(C1)
#print(C2)
torch
.
testing
.
assert_close
(
C1
,
C2
,
atol
=
1e-05
,
rtol
=
0.
0
05
)
torch
.
testing
.
assert_close
(
C1
,
C2
,
atol
=
1e-05
,
rtol
=
0.05
)
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
],
ids
=
[
'fp16'
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment