Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ed305f6b
Commit
ed305f6b
authored
Sep 28, 2023
by
Umang Yadav
Browse files
formatting
parent
9f4e3544
Changes
45
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
111 additions
and
106 deletions
+111
-106
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+1
-1
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+6
-6
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp
...rary/include/ck/host/device_batched_gemm_softmax_gemm.hpp
+75
-74
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp
.../src/jit_library/src/device_batched_gemm_softmax_gemm.cpp
+17
-16
library/src/jit_library/src/device_gemm_multiple_d.cpp
library/src/jit_library/src/device_gemm_multiple_d.cpp
+12
-9
No files found.
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
ed305f6b
include/ck/utility/magic_division.hpp
View file @
ed305f6b
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp
View file @
ed305f6b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -99,7 +99,8 @@ struct Problem
static
const
std
::
size_t
B1BlockLdsAddExtraN_idx
=
52
;
static
const
std
::
size_t
CShuffleMXdlPerWavePerShuffle_idx
=
53
;
static
const
std
::
size_t
CShuffleNXdlPerWavePerShuffle_idx
=
54
;
static
const
std
::
size_t
CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx
=
55
;
static
const
std
::
size_t
CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx
=
55
;
static
const
std
::
size_t
CBlockTransferScalarPerVector_NWaveNPerXdl_idx
=
56
;
static
const
std
::
size_t
MaskOutUpperTriangle_idx
=
57
;
};
...
...
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp
View file @
ed305f6b
...
...
@@ -80,7 +80,8 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
n1_per_block
=
std
::
stoi
(
n1_per_block_str
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
O
,
m_per_block
,
n1_per_block
);
params
[
GEMMSpecialization_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
O
,
m_per_block
,
n_per_block
,
k_per_block
,
n1_per_block
);
params
[
GEMMSpecialization_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
O
,
m_per_block
,
n_per_block
,
k_per_block
,
n1_per_block
);
std
::
string
str
=
std
::
accumulate
(
params
.
begin
()
+
1
,
...
...
library/src/jit_library/src/device_gemm_multiple_d.cpp
View file @
ed305f6b
...
...
@@ -101,18 +101,21 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
if
(
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
)
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if
(
EDataType
==
DataType
::
Half
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Half
;
}))
if
(
EDataType
==
DataType
::
Half
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Half
;
}))
{
params
[
params
.
size
()
-
3
]
=
"8"
;
}
if
(
EDataType
==
DataType
::
Float
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Float
;
}))
if
(
EDataType
==
DataType
::
Float
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Float
;
}))
{
params
[
params
.
size
()
-
3
]
=
"4"
;
}
if
(
EDataType
==
DataType
::
Int32
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Int32
;
}))
if
(
EDataType
==
DataType
::
Int32
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Int32
;
}))
{
params
[
params
.
size
()
-
3
]
=
"4"
;
}
...
...
@@ -141,7 +144,7 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std
::
string
{},
[](
const
std
::
string
&
a
,
const
std
::
string
&
b
)
{
return
a
.
empty
()
?
b
:
a
+
", "
+
b
;
});
str
=
params
.
front
()
+
"< "
+
str
+
">"
;
if
(
params
.
back
().
find
(
"v2"
)
!=
std
::
string
::
npos
and
K
%
k_per_block
!=
0
)
if
(
params
.
back
().
find
(
"v2"
)
!=
std
::
string
::
npos
and
K
%
k_per_block
!=
0
)
str
=
""
;
return
Solution
{
str
,
block_size
,
grid_size
};
...
...
@@ -159,7 +162,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
for
(
std
::
size_t
i
=
0
;
i
<
num_instances
;
++
i
)
{
auto
solution
=
MakeSolution
(
i
,
arch
);
if
(
solution
.
template_str
!=
""
)
if
(
solution
.
template_str
!=
""
)
solutions
.
push_back
(
solution
);
}
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment