Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
830dff7a
Commit
830dff7a
authored
May 24, 2023
by
Alan Turner
Browse files
Formatting
parent
d46c7224
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
42 deletions
+42
-42
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+10
-10
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
...gets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
+1
-2
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+31
-30
No files found.
src/targets/gpu/compile_hip_code_object.cpp
View file @
830dff7a
...
@@ -167,19 +167,19 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
...
@@ -167,19 +167,19 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
auto
path
=
fs
::
path
{
"migraphx"
}
/
"kernels"
/
name
;
auto
path
=
fs
::
path
{
"migraphx"
}
/
"kernels"
/
name
;
return
src_file
{
path
,
c
};
return
src_file
{
path
,
c
};
});
});
if
(
not
options
.
embedded_headers
.
empty
())
if
(
not
options
.
embedded_headers
.
empty
())
{
{
std
::
transform
(
options
.
embedded_headers
.
begin
(),
std
::
transform
(
options
.
embedded_headers
.
begin
(),
options
.
embedded_headers
.
end
(),
options
.
embedded_headers
.
end
(),
std
::
back_inserter
(
srcs
),
std
::
back_inserter
(
srcs
),
[](
auto
&&
p
)
{
[](
auto
&&
p
)
{
auto
&&
name
=
p
.
first
;
auto
&&
name
=
p
.
first
;
auto
&&
c
=
p
.
second
;
auto
&&
c
=
p
.
second
;
auto
path
=
fs
::
path
{
"migraphx"
}
/
"kernels"
/
name
;
auto
path
=
fs
::
path
{
"migraphx"
}
/
"kernels"
/
name
;
return
src_file
{
path
,
c
};
return
src_file
{
path
,
c
};
});
});
}
}
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
std
::
make_pair
(
content
.
data
(),
content
.
data
()
+
content
.
size
())});
std
::
make_pair
(
content
.
data
(),
content
.
data
()
+
content
.
size
())});
auto
args_hpp
=
auto
args_hpp
=
...
...
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
View file @
830dff7a
...
@@ -42,8 +42,7 @@ struct hip_compile_options
...
@@ -42,8 +42,7 @@ struct hip_compile_options
std
::
string
kernel_name
=
"kernel"
;
std
::
string
kernel_name
=
"kernel"
;
std
::
string
params
=
""
;
std
::
string
params
=
""
;
std
::
vector
<
shape
>
virtual_inputs
=
{};
std
::
vector
<
shape
>
virtual_inputs
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
embedded_headers
;
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
embedded_headers
;
/**
/**
* @brief Set the launch parameters but allow v to override the values
* @brief Set the launch parameters but allow v to override the values
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
830dff7a
...
@@ -40,7 +40,6 @@
...
@@ -40,7 +40,6 @@
#include "ck/include/device_gemm_multiple_d.hpp"
#include "ck/include/device_gemm_multiple_d.hpp"
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -79,8 +78,6 @@ __global__ void ${kernel}(${params})
...
@@ -79,8 +78,6 @@ __global__ void ${kernel}(${params})
)__migraphx__"
;
)__migraphx__"
;
static
bool
transposed_matrix
(
const
shape
&
s
)
{
return
s
.
strides
().
back
()
!=
1
;
}
static
bool
transposed_matrix
(
const
shape
&
s
)
{
return
s
.
strides
().
back
()
!=
1
;
}
template
<
class
F
,
class
Action
>
template
<
class
F
,
class
Action
>
...
@@ -237,41 +234,46 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -237,41 +234,46 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto
n
=
c_shape
.
lens
().
back
();
auto
n
=
c_shape
.
lens
().
back
();
auto
k
=
a_shape
.
lens
().
back
();
auto
k
=
a_shape
.
lens
().
back
();
const
bool
transA
=
transposed_matrix
(
a_shape
);
const
bool
transA
=
transposed_matrix
(
a_shape
);
const
bool
transB
=
transposed_matrix
(
b_shape
);
const
bool
transB
=
transposed_matrix
(
b_shape
);
const
bool
transE
=
transposed_matrix
(
c_shape
);
const
bool
transE
=
transposed_matrix
(
c_shape
);
const
auto
a_type
=
get_type
(
a_shape
);
const
auto
a_type
=
get_type
(
a_shape
);
const
auto
b_type
=
get_type
(
b_shape
);
const
auto
b_type
=
get_type
(
b_shape
);
const
auto
e_type
=
get_type
(
c_shape
);
const
auto
e_type
=
get_type
(
c_shape
);
std
::
vector
<
bool
>
ds_layout
;
std
::
vector
<
bool
>
ds_layout
;
std
::
transform
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
std
::
back_inserter
(
ds_layout
),
[](
const
auto
&
i
){
return
transposed_matrix
(
i
);
});
std
::
transform
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
std
::
back_inserter
(
ds_layout
),
[](
const
auto
&
i
)
{
return
transposed_matrix
(
i
);
});
std
::
vector
<
std
::
string
>
ds_type
;
std
::
vector
<
std
::
string
>
ds_type
;
std
::
transform
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
std
::
back_inserter
(
ds_type
),
[](
const
auto
&
i
){
return
get_type
(
i
);
});
std
::
transform
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
std
::
back_inserter
(
ds_type
),
[](
const
auto
&
i
)
{
return
get_type
(
i
);
});
std
::
string
ck_passthrough
=
"ck_passthrough"
;
std
::
string
ck_passthrough
=
"ck_passthrough"
;
std
::
string
cde_op
=
ck_passthrough
;
std
::
string
cde_op
=
ck_passthrough
;
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
if
(
v
.
contains
(
"post"
))
if
(
v
.
contains
(
"post"
))
{
{
cde_op
=
v
.
at
(
"post"
).
to
<
std
::
string
>
();
cde_op
=
v
.
at
(
"post"
).
to
<
std
::
string
>
();
}
}
auto
problem
=
auto
problem
=
ck
::
tensor_operation
::
device
::
device_gemm_multiple_d
::
Problem
{
ck
::
tensor_operation
::
device
::
device_gemm_multiple_d
::
static_cast
<
ck
::
index_t
>
(
m
),
Problem
{
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
k
),
transA
,
transA
,
transB
,
transB
,
transE
,
transE
,
ds_layout
,
ds_layout
,
a_type
,
a_type
,
b_type
,
b_type
,
e_type
,
e_type
,
ds_type
,
ds_type
,
ck_passthrough
,
ck_passthrough
,
ck_passthrough
,
ck_passthrough
,
cde_op
};
cde_op
};
const
auto
include_header
=
problem
.
GetIncludeHeader
();
const
auto
include_header
=
problem
.
GetIncludeHeader
();
const
auto
ck_headers
=
problem
.
GetHeaders
();
const
auto
ck_headers
=
problem
.
GetHeaders
();
...
@@ -281,7 +283,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -281,7 +283,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const
auto
blocks_per_batch
=
solution
.
grid_size
;
const
auto
blocks_per_batch
=
solution
.
grid_size
;
const
auto
block_size
=
solution
.
block_size
;
const
auto
block_size
=
solution
.
block_size
;
hip_compile_options
options
;
hip_compile_options
options
;
options
.
embedded_headers
=
ck_headers
;
options
.
embedded_headers
=
ck_headers
;
auto
grid_size
=
can_fold_batch
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
auto
grid_size
=
can_fold_batch
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment