Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
e01d4e03
Commit
e01d4e03
authored
Apr 30, 2023
by
Tim Dettmers
Browse files
Fixed bank conflicts in non-vector load 422.
parent
c35ed09b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
6 deletions
+6
-6
csrc/kernels.cu
csrc/kernels.cu
+6
-6
No files found.
csrc/kernels.cu
View file @
e01d4e03
...
@@ -3060,11 +3060,11 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
...
@@ -3060,11 +3060,11 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
T
local_B
[
64
/
BITS
];
T
local_B
[
64
/
BITS
];
T
local_C
[
8
];
T
local_C
[
8
];
const
int
a_tile_offset
=
32
*
16
;
const
int
a_tile_offset
=
32
*
16
+
16
;
const
int
b_tile_offset
=
16
*
8
;
const
int
b_tile_offset
=
16
*
8
+
16
;
__shared__
T
smem_A
[
WARPS
*
32
*
16
*
2
];
__shared__
T
smem_A
[
WARPS
*
32
*
16
*
2
+
(
16
*
1
)
];
__shared__
T
smem_B
[
WARPS
*
16
*
8
*
2
];
__shared__
T
smem_B
[
WARPS
*
16
*
8
*
2
+
(
16
*
1
)
];
__shared__
T
smem_C
[
WARPS
*
32
*
8
];
__shared__
T
smem_C
[
WARPS
*
32
*
8
];
wmma
::
fragment
<
wmma
::
matrix_a
,
32
,
8
,
16
,
half
,
wmma
::
row_major
>
a_frag
;
wmma
::
fragment
<
wmma
::
matrix_a
,
32
,
8
,
16
,
half
,
wmma
::
row_major
>
a_frag
;
...
@@ -3114,8 +3114,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
...
@@ -3114,8 +3114,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
wmma
::
load_matrix_sync
(
a_frag
,
&
(
smem_A
[
0
]),
16
);
// 111 mu
wmma
::
load_matrix_sync
(
a_frag
,
&
(
smem_A
[
0
]),
16
);
// 111 mu
wmma
::
load_matrix_sync
(
b_frag
,
&
(
smem_B
[
0
]),
16
);
// 35 mu
wmma
::
load_matrix_sync
(
b_frag
,
&
(
smem_B
[
0
]),
16
);
// 35 mu
wmma
::
load_matrix_sync
(
a2_frag
,
&
(
smem_A
[
32
*
16
]),
16
);
// 111 mu
wmma
::
load_matrix_sync
(
a2_frag
,
&
(
smem_A
[
a_tile_offset
]),
16
);
// 111 mu
wmma
::
load_matrix_sync
(
b2_frag
,
&
(
smem_B
[
16
*
8
]),
16
);
// 35 mu
wmma
::
load_matrix_sync
(
b2_frag
,
&
(
smem_B
[
b_tile_offset
]),
16
);
// 35 mu
wmma
::
mma_sync
(
c_frag
,
a_frag
,
b_frag
,
c_frag
);
wmma
::
mma_sync
(
c_frag
,
a_frag
,
b_frag
,
c_frag
);
wmma
::
mma_sync
(
c_frag
,
a2_frag
,
b2_frag
,
c_frag
);
wmma
::
mma_sync
(
c_frag
,
a2_frag
,
b2_frag
,
c_frag
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment