Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
15bf3d62
Commit
15bf3d62
authored
Aug 30, 2018
by
Paul
Browse files
Merge branch 'master' into memory_coloring
parents
7fa4d978
d2778c9e
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
595 additions
and
77 deletions
+595
-77
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+66
-0
src/include/migraph/functional.hpp
src/include/migraph/functional.hpp
+49
-1
src/include/migraph/fwd_conv_batchnorm_rewrite.hpp
src/include/migraph/fwd_conv_batchnorm_rewrite.hpp
+19
-0
src/include/migraph/generate.hpp
src/include/migraph/generate.hpp
+23
-1
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+5
-0
src/include/migraph/tracer.hpp
src/include/migraph/tracer.hpp
+1
-8
src/include/migraph/verify_args.hpp
src/include/migraph/verify_args.hpp
+51
-0
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+6
-2
src/onnx/cifar10.cpp
src/onnx/cifar10.cpp
+107
-0
src/onnx/mnist.cpp
src/onnx/mnist.cpp
+11
-14
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+1
-1
src/onnx/softmax.hpp
src/onnx/softmax.hpp
+14
-0
src/onnx/verify_onnx.cpp
src/onnx/verify_onnx.cpp
+2
-16
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-0
src/targets/gpu/device/add.cpp
src/targets/gpu/device/add.cpp
+15
-0
src/targets/gpu/device/add_relu.cpp
src/targets/gpu/device/add_relu.cpp
+2
-3
src/targets/gpu/device/include/migraph/gpu/device/launch.hpp
src/targets/gpu/device/include/migraph/gpu/device/launch.hpp
+10
-2
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+205
-23
src/targets/gpu/device/include/migraph/gpu/device/tensor.hpp
src/targets/gpu/device/include/migraph/gpu/device/tensor.hpp
+6
-6
No files found.
src/CMakeLists.txt
View file @
15bf3d62
...
@@ -3,6 +3,7 @@ add_library(migraph
...
@@ -3,6 +3,7 @@ add_library(migraph
auto_contiguous.cpp
auto_contiguous.cpp
dead_code_elimination.cpp
dead_code_elimination.cpp
eliminate_contiguous.cpp
eliminate_contiguous.cpp
fwd_conv_batchnorm_rewrite.cpp
env.cpp
env.cpp
generate.cpp
generate.cpp
program.cpp
program.cpp
...
...
src/fwd_conv_batchnorm_rewrite.cpp
0 → 100644
View file @
15bf3d62
#include <migraph/fwd_conv_batchnorm_rewrite.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>
namespace migraph {

/// Folds a `batch_norm_inference` instruction into the `convolution` that
/// feeds it, so that no batch-norm has to be evaluated at run time.
///
/// The rewrite only fires when
///   - the batch-norm's first argument is a `convolution` instruction,
///   - the convolution's weights (argument 1) are a literal, and
///   - the batch-norm's scale, bias, mean and variance (arguments 1..4) are
///     all literals.
///
/// Batch-norm inference computes
///     y = gamma * (x - mean) / sqrt(variance + epsilon) + bias
/// With x = conv(input, W) and s = gamma / sqrt(variance + epsilon) this is
///     y = conv(input, s * W) + (bias - s * mean)
/// so the convolution gets rescaled weights and the batch-norm is replaced by
/// an `add` of a broadcast bias literal.
///
/// @param p Program that is rewritten in place.
void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        if(ins->op.name() != "batch_norm_inference")
            continue;
        // Every batch-norm parameter (everything but the data input) must be
        // known at compile time.
        if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
               return arg->op.name() == "@literal";
           }))
            continue;

        auto conv_ins = ins->arguments[0];
        if(conv_ins->op.name() != "convolution")
            continue;
        if(conv_ins->arguments[1]->op.name() != "@literal")
            continue;

        // Get scale, bias, mean, variance from instruction_ref
        const auto& gamma    = ins->arguments[1]->get_literal();
        const auto& bias     = ins->arguments[2]->get_literal();
        const auto& mean     = ins->arguments[3]->get_literal();
        const auto& variance = ins->arguments[4]->get_literal();
        // Get epsilon
        auto bn_op   = any_cast<batch_norm_inference>(ins->op);
        auto epsilon = bn_op.epsilon;
        // Get convolution weights
        const auto& weights = conv_ins->arguments[1]->get_literal();
        // Get convolution op
        auto conv_op      = conv_ins->op;
        auto weights_lens = weights.get_shape().lens();

        argument new_weights{weights.get_shape()};
        argument new_bias{bias.get_shape()};
        visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
            [&](auto weights2,
                auto gamma2,
                auto bias2,
                auto mean2,
                auto variance2,
                auto new_weights2,
                auto new_bias2) {
                // Scale every filter k by gamma[k] / sqrt(variance[k] + epsilon).
                dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
                    [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
                        new_weights2(k, c, h, w) =
                            gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
                    });
                // The mean term must carry the same gamma scale as the
                // weights; leaving gamma out is only correct when gamma == 1.
                dfor(new_bias.get_shape().elements())([&](std::size_t c) {
                    new_bias2(c) =
                        bias2(c) - (gamma2(c) * mean2(c) / std::sqrt(variance2(c) + epsilon));
                });
            });

        // Replace convolution instruction with updated weights
        auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
        auto l_bias    = p.add_literal({new_bias.get_shape(), new_bias.data()});
        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
        // NOTE(review): broadcast{1} presumably expands the per-channel bias
        // along axis 1 to the convolution's output shape - confirm.
        auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias);
        p.replace_instruction(ins, add{}, {c, b});
    }
}

} // namespace migraph
src/include/migraph/functional.hpp
View file @
15bf3d62
...
@@ -5,6 +5,14 @@
...
@@ -5,6 +5,14 @@
namespace
migraph
{
namespace
migraph
{
struct
swallow
{
template
<
class
...
Ts
>
constexpr
swallow
(
Ts
&&
...)
{
}
};
namespace
detail
{
namespace
detail
{
template
<
class
R
,
class
F
>
template
<
class
R
,
class
F
>
...
@@ -19,8 +27,48 @@ struct fix_f
...
@@ -19,8 +27,48 @@ struct fix_f
}
}
};
};
template
<
std
::
size_t
...>
struct
seq
{
using
type
=
seq
;
};
template
<
class
,
class
>
struct
merge_seq
;
template
<
std
::
size_t
...
Xs
,
std
::
size_t
...
Ys
>
struct
merge_seq
<
seq
<
Xs
...
>
,
seq
<
Ys
...
>>
:
seq
<
Xs
...,
(
sizeof
...(
Xs
)
+
Ys
)...
>
{
};
template
<
std
::
size_t
N
>
struct
gens
:
merge_seq
<
typename
gens
<
N
/
2
>::
type
,
typename
gens
<
N
-
N
/
2
>::
type
>
{
};
template
<
>
struct
gens
<
0
>
:
seq
<>
{
};
template
<
>
struct
gens
<
1
>
:
seq
<
0
>
{
};
template
<
class
F
,
std
::
size_t
...
Ns
>
constexpr
void
repeat_c_impl
(
F
f
,
seq
<
Ns
...
>
)
{
swallow
{(
f
(
std
::
integral_constant
<
std
::
size_t
,
Ns
>
{}),
0
)...};
}
}
// namespace detail
}
// namespace detail
template
<
std
::
size_t
N
,
class
F
>
constexpr
void
repeat_c
(
F
f
)
{
detail
::
repeat_c_impl
(
f
,
detail
::
gens
<
N
>
{});
}
/// Implements a fix-point combinator
/// Implements a fix-point combinator
template
<
class
R
,
class
F
>
template
<
class
R
,
class
F
>
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
...
@@ -35,7 +83,7 @@ auto fix(F f)
...
@@ -35,7 +83,7 @@ auto fix(F f)
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
make_sequence
(
Ts
...
xs
)
auto
pack
(
Ts
...
xs
)
{
{
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
}
}
...
...
src/include/migraph/fwd_conv_batchnorm_rewrite.hpp
0 → 100644
View file @
15bf3d62
#ifndef MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {

// Only referenced by reference below, so a forward declaration is enough.
struct program;

/// Program pass identified by the name "fwd_conv_batchnorm_rewrite".
/// The transformation itself is implemented out of line in `apply`.
struct fwd_conv_batchnorm_rewrite
{
    /// Returns the identifier of this pass.
    std::string name() const
    {
        return "fwd_conv_batchnorm_rewrite";
    }

    /// Runs the pass on `p`, which may be modified in place.
    void apply(program& p) const;
};

} // namespace migraph
#endif
src/include/migraph/generate.hpp
View file @
15bf3d62
...
@@ -12,7 +12,11 @@ constexpr T normalize(unsigned long z)
...
@@ -12,7 +12,11 @@ constexpr T normalize(unsigned long z)
{
{
if
(
z
==
0
)
if
(
z
==
0
)
return
0
;
return
0
;
return
(
2.0
/
z
)
-
1.0
;
const
auto
max
=
32768
;
const
double
range
=
max
/
2
;
// NOLINT
double
result
=
(
z
%
max
)
/
range
;
result
-=
1
;
return
result
;
}
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_signed
<
T
>{}
and
not
std
::
is_floating_point
<
T
>
{})
>
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_signed
<
T
>{}
and
not
std
::
is_floating_point
<
T
>
{})
>
...
@@ -54,11 +58,29 @@ struct xorshf96_generator
...
@@ -54,11 +58,29 @@ struct xorshf96_generator
}
}
};
};
template
<
class
T
>
struct
xorshift_generator
{
unsigned
long
x
;
xorshift_generator
(
unsigned
long
seed
=
0
)
:
x
(
521288629ULL
^
seed
)
{}
constexpr
T
operator
()()
noexcept
{
x
^=
x
>>
12U
;
x
^=
x
<<
25U
;
x
^=
x
>>
27U
;
return
normalize
<
T
>
(
x
*
0x2545F4914F6CDD1D
);
}
};
template
<
class
T
>
template
<
class
T
>
std
::
vector
<
T
>
generate_tensor_data
(
const
migraph
::
shape
&
s
,
unsigned
long
seed
=
0
)
std
::
vector
<
T
>
generate_tensor_data
(
const
migraph
::
shape
&
s
,
unsigned
long
seed
=
0
)
{
{
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
generate
(
result
.
begin
(),
result
.
end
(),
xorshf96_generator
<
T
>
{
seed
});
std
::
generate
(
result
.
begin
(),
result
.
end
(),
xorshf96_generator
<
T
>
{
seed
});
// std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
// std::generate(result.begin(), result.end(), []{ return 1; });
return
result
;
return
result
;
}
}
...
...
src/include/migraph/instruction.hpp
View file @
15bf3d62
...
@@ -115,6 +115,11 @@ struct instruction
...
@@ -115,6 +115,11 @@ struct instruction
}
}
shape
get_shape
()
const
{
return
result
;
}
shape
get_shape
()
const
{
return
result
;
}
const
literal
&
get_literal
()
const
{
assert
(
op
.
name
()
==
"@literal"
);
return
lit
;
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
...
...
src/include/migraph/tracer.hpp
View file @
15bf3d62
...
@@ -2,17 +2,10 @@
...
@@ -2,17 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#include <ostream>
#include <ostream>
#include <migraph/functional.hpp>
namespace
migraph
{
namespace
migraph
{
struct
swallow
{
template
<
class
...
Ts
>
swallow
(
Ts
&&
...)
{
}
};
struct
tracer
struct
tracer
{
{
tracer
()
{}
tracer
()
{}
...
...
src/include/migraph/verify_args.hpp
0 → 100644
View file @
15bf3d62
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#include <migraph/verify.hpp>
#include <migraph/argument.hpp>
namespace migraph {

/// Compares a CPU result against a GPU result and prints a diagnostic report
/// to std::cout when `verify_range` rejects the pair. Nothing is printed when
/// the ranges verify.
///
/// @param name      Label printed after "FAILED: " to identify the comparison.
/// @param cpu_arg   Reference result.
/// @param gpu_arg   Result under test.
/// @param tolerance Forwarded unchanged to `verify_range`
///                  (NOTE(review): units are defined by verify_range - confirm).
inline void verify_args(const std::string& name,
                        const argument& cpu_arg,
                        const argument& gpu_arg,
                        double tolerance = 80)
{
    visit_all(cpu_arg, gpu_arg)([&](auto cpu, auto gpu) {
        // Nothing to report when the two ranges agree.
        if(verify_range(cpu, gpu, tolerance))
            return;

        // TODO: Check for nans
        std::cout << "FAILED: " << name << std::endl;

        // Only dump the raw values of small tensors.
        if(cpu.size() < 32)
            std::cout << "cpu:" << cpu << std::endl;
        if(gpu.size() < 32)
            std::cout << "gpu:" << gpu << std::endl;

        if(range_zero(cpu))
            std::cout << "Cpu data is all zeros" << std::endl;
        if(range_zero(gpu))
            std::cout << "Gpu data is all zeros" << std::endl;

        // Report the first position where the two ranges differ.
        auto first_diff = mismatch_idx(cpu, gpu, float_equal);
        if(first_diff < range_distance(cpu))
        {
            std::cout << "Mismatch at " << first_diff << ": " << cpu[first_diff]
                      << " != " << gpu[first_diff] << std::endl;
        }

        // A negative index from find_idx is treated as "not found".
        auto cpu_bad = find_idx(cpu, not_finite);
        if(cpu_bad >= 0)
            std::cout << "Non finite number found in cpu at " << cpu_bad << ": " << cpu[cpu_bad]
                      << std::endl;

        auto gpu_bad = find_idx(gpu, not_finite);
        if(gpu_bad >= 0)
            std::cout << "Non finite number found in gpu at " << gpu_bad << ": " << gpu[gpu_bad]
                      << std::endl;

        std::cout << std::endl;
    });
}

} // namespace migraph
#endif
src/onnx/CMakeLists.txt
View file @
15bf3d62
...
@@ -17,11 +17,15 @@ rocm_clang_tidy_check(read_onnx)
...
@@ -17,11 +17,15 @@ rocm_clang_tidy_check(read_onnx)
target_link_libraries
(
read_onnx migraph_onnx
)
target_link_libraries
(
read_onnx migraph_onnx
)
if
(
MIGRAPH_ENABLE_GPU
)
add_executable
(
mnist mnist.cpp
)
add_executable
(
mnist mnist.cpp
)
rocm_clang_tidy_check
(
mnist
)
rocm_clang_tidy_check
(
mnist
)
target_link_libraries
(
mnist migraph_cpu migraph_onnx
)
target_link_libraries
(
mnist migraph_cpu migraph_gpu migraph_onnx
)
add_executable
(
cifar10 cifar10.cpp
)
rocm_clang_tidy_check
(
cifar10
)
target_link_libraries
(
cifar10 migraph_cpu migraph_gpu migraph_onnx
)
if
(
MIGRAPH_ENABLE_GPU
)
add_executable
(
verify_onnx verify_onnx.cpp
)
add_executable
(
verify_onnx verify_onnx.cpp
)
rocm_clang_tidy_check
(
verify_onnx
)
rocm_clang_tidy_check
(
verify_onnx
)
target_link_libraries
(
verify_onnx migraph_onnx migraph_cpu migraph_gpu
)
target_link_libraries
(
verify_onnx migraph_onnx migraph_cpu migraph_gpu
)
...
...
src/onnx/cifar10.cpp
0 → 100644
View file @
15bf3d62
#include <cstdio>
#include <string>
#include <fstream>
#include <numeric>
#include <stdexcept>
#include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include "softmax.hpp"
/// Reads the first 10 records of a CIFAR-10 binary batch file.
///
/// Each record is one label byte followed by 3072 pixel bytes. Pixel bytes
/// are converted to floats in [0, 1] by dividing by 255.
///
/// @param full_path Path of the binary batch file.
/// @return A pair `(labels, data)`: `labels` holds the 10 label bytes and
///         `data` holds 10 * 3072 floats, image after image.
/// @throws std::runtime_error if the file cannot be opened, or if it holds
///         fewer than 10 complete records.
auto read_cifar10_images(const std::string& full_path)
{
    const size_t nimages          = 10;
    const size_t nbytes_per_image = 3072;
    const size_t record_size      = nbytes_per_image + 1; // label byte + pixels

    std::ifstream file(full_path, std::ios::binary);
    if(not file.is_open())
    {
        throw std::runtime_error("Cannot open file `" + full_path + "`!");
    }

    std::vector<uint8_t> raw_data(nimages * record_size);
    file.read(reinterpret_cast<char*>(raw_data.data()),
              static_cast<std::streamsize>(raw_data.size()));
    // A short read would otherwise be decoded as zero-filled images.
    if(static_cast<size_t>(file.gcount()) != raw_data.size())
    {
        throw std::runtime_error("File `" + full_path +
                                 "` does not contain 10 complete CIFAR-10 records!");
    }

    std::vector<uint8_t> labels(nimages);
    std::vector<float> data(nimages * nbytes_per_image);
    for(size_t i = 0; i < nimages; i++)
    {
        const uint8_t* record = raw_data.data() + i * record_size;
        labels[i]             = record[0];
        const uint8_t* pixels = record + 1;
        for(size_t j = 0; j < nbytes_per_image; j++)
        {
            data[i * nbytes_per_image + j] = pixels[j] / 255.0f;
        }
    }
    return std::make_pair(labels, data);
}
int
main
(
int
argc
,
char
const
*
argv
[])
{
if
(
argc
<
4
)
{
throw
std
::
runtime_error
(
"Usage: cifar10 [gpu | cpu] <onnx file> <cifar10 data file>"
);
}
std
::
string
gpu_cpu
=
argv
[
1
];
std
::
string
file
=
argv
[
2
];
std
::
string
datafile
=
argv
[
3
];
auto
prog
=
migraph
::
parse_onnx
(
file
);
std
::
cout
<<
prog
<<
std
::
endl
;
auto
imageset
=
read_cifar10_images
(
datafile
);
if
(
gpu_cpu
==
"gpu"
)
{
// GPU target
prog
.
compile
(
migraph
::
gpu
::
target
{});
migraph
::
program
::
parameter_map
m
;
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}};
for
(
auto
&&
x
:
prog
.
get_parameter_shapes
())
{
m
[
x
.
first
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
generate_argument
(
x
.
second
));
}
auto
labels
=
imageset
.
first
;
auto
input
=
imageset
.
second
;
auto
ptr
=
input
.
data
();
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
std
::
cout
<<
"label: "
<<
static_cast
<
uint32_t
>
(
labels
[
i
])
<<
" ----> "
;
m
[
"0"
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
argument
{
s
,
&
ptr
[
3072
*
i
]});
auto
result
=
migraph
::
gpu
::
from_gpu
(
prog
.
eval
(
m
));
std
::
vector
<
float
>
logits
;
result
.
visit
([
&
](
auto
output
)
{
logits
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
probs
=
softmax
<
float
>
(
logits
);
for
(
auto
x
:
probs
)
std
::
cout
<<
x
<<
" "
;
std
::
cout
<<
std
::
endl
<<
std
::
endl
;
}
}
else
{
// CPU target
prog
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}};
auto
labels
=
imageset
.
first
;
auto
input
=
imageset
.
second
;
auto
ptr
=
input
.
data
();
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
std
::
cout
<<
"label: "
<<
static_cast
<
uint32_t
>
(
labels
[
i
])
<<
" ----> "
;
auto
input3
=
migraph
::
argument
{
s
,
&
ptr
[
3072
*
i
]};
auto
result
=
prog
.
eval
({{
"0"
,
input3
}});
std
::
vector
<
float
>
logits
;
result
.
visit
([
&
](
auto
output
)
{
logits
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
probs
=
softmax
<
float
>
(
logits
);
for
(
auto
x
:
probs
)
std
::
cout
<<
x
<<
" "
;
std
::
cout
<<
std
::
endl
;
}
}
}
src/onnx/mnist.cpp
View file @
15bf3d62
...
@@ -6,9 +6,12 @@
...
@@ -6,9 +6,12 @@
#include <migraph/onnx.hpp>
#include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include <migraph/generate.hpp>
#include "softmax.hpp"
auto
reverse_int
(
unsigned
int
i
)
auto
reverse_int
(
unsigned
int
i
)
{
{
unsigned
char
c1
,
c2
,
c3
,
c4
;
unsigned
char
c1
,
c2
,
c3
,
c4
;
...
@@ -97,16 +100,6 @@ std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number
...
@@ -97,16 +100,6 @@ std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number
}
}
}
}
std
::
vector
<
float
>
softmax
(
std
::
vector
<
float
>
p
)
{
size_t
n
=
p
.
size
();
std
::
vector
<
float
>
result
(
n
);
std
::
transform
(
p
.
begin
(),
p
.
end
(),
result
.
begin
(),
[](
auto
x
)
{
return
std
::
exp
(
x
);
});
float
s
=
std
::
accumulate
(
result
.
begin
(),
result
.
end
(),
0.0
f
,
std
::
plus
<
float
>
());
std
::
transform
(
result
.
begin
(),
result
.
end
(),
result
.
begin
(),
[
=
](
auto
x
)
{
return
x
/
s
;
});
return
result
;
}
int
main
(
int
argc
,
char
const
*
argv
[])
int
main
(
int
argc
,
char
const
*
argv
[])
{
{
if
(
argc
>
3
)
if
(
argc
>
3
)
...
@@ -121,15 +114,19 @@ int main(int argc, char const* argv[])
...
@@ -121,15 +114,19 @@ int main(int argc, char const* argv[])
std
::
string
file
=
argv
[
1
];
std
::
string
file
=
argv
[
1
];
auto
prog
=
migraph
::
parse_onnx
(
file
);
auto
prog
=
migraph
::
parse_onnx
(
file
);
prog
.
compile
(
migraph
::
cpu
::
cpu_target
{});
std
::
cout
<<
prog
<<
std
::
endl
<<
std
::
endl
;
prog
.
compile
(
migraph
::
gpu
::
target
{});
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
1
,
28
,
28
}};
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
1
,
28
,
28
}};
std
::
cout
<<
s
<<
std
::
endl
;
std
::
cout
<<
s
<<
std
::
endl
;
auto
ptr
=
input
.
data
();
auto
ptr
=
input
.
data
();
migraph
::
program
::
parameter_map
m
;
m
[
"output"
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
generate_argument
(
prog
.
get_parameter_shape
(
"output"
)));
for
(
int
i
=
0
;
i
<
20
;
i
++
)
for
(
int
i
=
0
;
i
<
20
;
i
++
)
{
{
std
::
cout
<<
"label: "
<<
labels
[
i
]
<<
" ----> "
;
std
::
cout
<<
"label: "
<<
labels
[
i
]
<<
" ----> "
;
auto
input3
=
migraph
::
argument
{
s
,
&
ptr
[
784
*
i
]};
m
[
"0"
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
argument
{
s
,
&
ptr
[
784
*
i
]}
)
;
auto
result
=
prog
.
eval
({{
"Input3"
,
input3
}}
);
auto
result
=
migraph
::
gpu
::
from_gpu
(
prog
.
eval
(
m
)
);
std
::
vector
<
float
>
logits
;
std
::
vector
<
float
>
logits
;
result
.
visit
([
&
](
auto
output
)
{
logits
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
logits
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
probs
=
softmax
(
logits
);
std
::
vector
<
float
>
probs
=
softmax
(
logits
);
...
...
src/onnx/onnx.cpp
View file @
15bf3d62
...
@@ -234,7 +234,7 @@ struct onnx_parser
...
@@ -234,7 +234,7 @@ struct onnx_parser
}
}
if
(
contains
(
attributes
,
"momentum"
))
if
(
contains
(
attributes
,
"momentum"
))
{
{
epsilon
=
parse_value
(
attributes
.
at
(
"momentum"
)).
at
<
float
>
();
momentum
=
parse_value
(
attributes
.
at
(
"momentum"
)).
at
<
float
>
();
}
}
if
(
contains
(
attributes
,
"is_test"
))
if
(
contains
(
attributes
,
"is_test"
))
{
{
...
...
src/onnx/softmax.hpp
0 → 100644
View file @
15bf3d62
#include <vector>
#include <algorithm>
#include <cmath>
/// Computes the softmax of `p`: result[i] = exp(p[i]) / sum_j exp(p[j]).
///
/// The inputs are shifted by their maximum before exponentiation. This leaves
/// the mathematical result unchanged but keeps exp() from overflowing, which
/// would otherwise turn large logits into inf / inf = NaN. The sum is
/// accumulated in `T`, so `softmax<double>` keeps double precision.
///
/// @param p Input logits; may be empty.
/// @return A vector of the same length as `p` whose entries sum to 1
///         (empty when `p` is empty).
template <typename T>
std::vector<T> softmax(const std::vector<T>& p)
{
    std::vector<T> result(p.size());
    if(p.empty())
        return result;

    const T peak = *std::max_element(p.begin(), p.end());
    std::transform(
        p.begin(), p.end(), result.begin(), [peak](T x) { return std::exp(x - peak); });

    // Plain loop: no dependency on <numeric>, and the accumulator has type T.
    T total = T(0);
    for(const T& e : result)
        total += e;

    std::transform(
        result.begin(), result.end(), result.begin(), [total](T e) { return e / total; });
    return result;
}
src/onnx/verify_onnx.cpp
View file @
15bf3d62
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include <migraph/generate.hpp>
#include <migraph/verify.hpp>
#include <migraph/verify
_args
.hpp>
migraph
::
argument
run_cpu
(
const
std
::
string
&
file
)
migraph
::
argument
run_cpu
(
const
std
::
string
&
file
)
{
{
...
@@ -46,20 +46,6 @@ int main(int argc, char const* argv[])
...
@@ -46,20 +46,6 @@ int main(int argc, char const* argv[])
auto
x
=
run_cpu
(
file
);
auto
x
=
run_cpu
(
file
);
auto
y
=
run_gpu
(
file
);
auto
y
=
run_gpu
(
file
);
visit_all
(
x
,
y
)([](
auto
cpu
,
auto
gpu
)
{
migraph
::
verify_args
(
file
,
x
,
y
,
100
);
if
(
migraph
::
verify_range
(
cpu
,
gpu
,
100
))
{
std
::
cout
<<
"Passed"
<<
std
::
endl
;
}
else
{
std
::
cout
<<
"Not equal"
<<
std
::
endl
;
std
::
cout
<<
"cpu:"
<<
std
::
endl
;
std
::
cout
<<
cpu
<<
std
::
endl
;
std
::
cout
<<
"gpu:"
<<
std
::
endl
;
std
::
cout
<<
gpu
<<
std
::
endl
;
}
});
}
}
}
}
src/targets/gpu/CMakeLists.txt
View file @
15bf3d62
...
@@ -11,6 +11,7 @@ if(NOT TARGET MIOpen)
...
@@ -11,6 +11,7 @@ if(NOT TARGET MIOpen)
endif
()
endif
()
add_library
(
migraph_device
add_library
(
migraph_device
device/add.cpp
device/add_relu.cpp
device/add_relu.cpp
device/contiguous.cpp
device/contiguous.cpp
)
)
...
...
src/targets/gpu/device/add.cpp
0 → 100644
View file @
15bf3d62
#include <migraph/gpu/device/add.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace migraph {
namespace gpu {
namespace device {

/// Writes the element-wise sum of `arg1` and `arg2` into `result`.
/// The traversal of the arguments is delegated to `nary`; this function only
/// supplies the per-element operation.
void add(const argument& result, const argument& arg1, const argument& arg2)
{
    const auto sum = [](auto lhs, auto rhs) { return lhs + rhs; };
    nary(result, arg1, arg2)(sum);
}

} // namespace device
} // namespace gpu
} // namespace migraph
src/targets/gpu/device/add_relu.cpp
View file @
15bf3d62
...
@@ -5,10 +5,9 @@ namespace migraph {
...
@@ -5,10 +5,9 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
add_relu
(
argument
result
,
argument
arg1
,
argument
arg2
)
void
add_relu
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
{
nary_standard
(
std
::
move
(
result
),
std
::
move
(
arg1
),
std
::
move
(
arg2
))(
nary
(
result
,
arg1
,
arg2
)([](
auto
x
,
auto
y
)
{
return
std
::
max
<
decltype
(
x
+
y
)
>
(
0
,
x
+
y
);
});
[](
auto
x
,
auto
y
)
{
return
max
(
0
,
x
+
y
);
});
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/include/migraph/gpu/device/launch.hpp
View file @
15bf3d62
...
@@ -33,10 +33,10 @@ inline auto launch(std::size_t global, std::size_t local)
...
@@ -33,10 +33,10 @@ inline auto launch(std::size_t global, std::size_t local)
};
};
}
}
inline
auto
gs_launch
(
std
::
size_t
n
,
std
::
size_t
local
=
512
)
inline
auto
gs_launch
(
std
::
size_t
n
,
std
::
size_t
local
=
1024
)
{
{
std
::
size_t
groups
=
1
+
n
/
local
;
std
::
size_t
groups
=
1
+
n
/
local
;
std
::
size_t
nglobal
=
std
::
min
<
std
::
size_t
>
(
512
,
groups
)
*
local
;
std
::
size_t
nglobal
=
std
::
min
<
std
::
size_t
>
(
256
,
groups
)
*
local
;
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
launch
(
nglobal
,
local
)([
=
](
auto
idx
)
{
launch
(
nglobal
,
local
)([
=
](
auto
idx
)
{
...
@@ -48,6 +48,14 @@ inline auto gs_launch(std::size_t n, std::size_t local = 512)
...
@@ -48,6 +48,14 @@ inline auto gs_launch(std::size_t n, std::size_t local = 512)
};
};
}
}
// Workaround hcc's broken tile_static macro
#ifdef tile_static
#undef tile_static
#define MIGRAPH_DEVICE_SHARED __attribute__((tile_static))
#else
#define MIGRAPH_DEVICE_SHARED __shared__
#endif
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
}
// namespace migraph
}
// namespace migraph
...
...
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
View file @
15bf3d62
...
@@ -10,16 +10,25 @@ namespace migraph {
...
@@ -10,16 +10,25 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
template
<
class
...
Arguments
>
template
<
class
T
>
auto
nary
(
argument
result
,
Arguments
...
args
)
using
vec4
=
T
__attribute__
((
ext_vector_type
(
4
)));
template
<
class
T
>
__device__
__host__
vec4
<
T
>*
as_vec4
(
T
*
x
)
{
{
return
[
=
](
auto
f
)
{
return
reinterpret_cast
<
vec4
<
T
>*>
(
x
);
if
(
all_of
({
args
...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
}))
}
nary_standard
(
result
,
args
...)(
f
);
else
nary_nonstandard
(
result
,
args
...)(
f
);
};
template
<
class
T
>
__device__
__host__
T
*
as_pointer
(
vec4
<
T
>*
x
)
{
return
reinterpret_cast
<
T
*>
(
x
);
}
template
<
class
...
Ts
>
auto
pack_vec4
(
Ts
...
xs
)
{
return
[
=
](
auto
f
,
std
::
size_t
n
)
{
return
f
(
as_vec4
(
xs
)[
n
]...);
};
}
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
...
@@ -28,14 +37,12 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
...
@@ -28,14 +37,12 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
auto
data
=
make_sequence
(
auto
data
=
pack
(
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
().
lens
(),
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
inputs
.
data
())...);
inputs
.
get_shape
().
strides
()},
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
inputs
.
data
())...);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
.
lens
(),
output_shape
.
strides
());
auto
*
outp
=
output
.
data
();
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
data
([
&
](
auto
...
ps
)
{
data
([
&
](
auto
&&
...
ps
)
{
auto
outidx
=
out_desc
.
multi
(
i
);
auto
outidx
=
out_desc
.
multi
(
i
);
outp
[
i
]
=
f
(
ps
.
second
[
ps
.
first
.
linear
(
outidx
)]...);
outp
[
i
]
=
f
(
ps
.
second
[
ps
.
first
.
linear
(
outidx
)]...);
});
});
...
@@ -44,24 +51,199 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
...
@@ -44,24 +51,199 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
});
});
}
}
template
<
class
...
Arguments
>
template
<
class
F
>
auto
nary_nonstandard
(
argument
result
,
Arguments
...
args
)
void
binary_broadcast_vec_impl
(
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
auto
*
xp
=
as_vec4
(
input1
.
data
());
auto
*
yp
=
as_vec4
(
input2
.
data
());
auto
*
outp
=
as_vec4
(
output
.
data
());
const
std
::
size_t
vec_size
=
4
;
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
()
/
vec_size
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
vec4
<
type
>
buffer
[
2048
/
vec_size
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
yp
[
i
];
}
__syncthreads
();
auto
*
bp
=
as_pointer
(
buffer
);
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
auto
bidx
=
((
i
*
vec_size
)
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
bp
[
bidx
];
vec4
<
type
>
x
=
xp
[
i
];
vec4
<
type
>
out
=
outp
[
i
];
for
(
std
::
size_t
j
=
0
;
j
<
vec_size
;
j
++
)
{
out
[
j
]
=
f
(
x
[
j
],
b
);
}
outp
[
i
]
=
out
;
}
});
});
}
template
<
class
F
>
void
binary_broadcast_impl
(
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
{
return
[
=
](
auto
f
)
{
return
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
auto
*
xp
=
input1
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
outp
=
output
.
data
();
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
();
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
type
buffer
[
2048
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
yp
[
i
];
}
__syncthreads
();
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
buffer
[
bidx
];
type
x
=
xp
[
i
];
outp
[
i
]
=
f
(
x
,
b
);
}
});
});
}
}
template
<
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
void
nary_standard
_vec_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
auto
data
=
make_sequence
(
inputs
.
data
()...);
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
const
std
::
size_t
vec_size
=
4
;
auto
data
=
pack_vec4
(
inputs
.
data
()...);
auto
*
outp
=
as_vec4
(
output
.
data
());
gs_launch
(
output_shape
.
elements
()
/
vec_size
)([
=
](
auto
i
)
{
vec4
<
type
>
out
=
outp
[
i
];
data
(
[
&
](
auto
...
xs
)
{
for
(
std
::
size_t
j
=
0
;
j
<
vec_size
;
j
++
)
{
out
[
j
]
=
f
(
xs
[
j
]...);
}
},
i
);
outp
[
i
]
=
out
;
});
});
}
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
auto
data
=
pack
(
inputs
.
data
()...);
auto
*
outp
=
output
.
data
();
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())(
gs_launch
(
output_shape
.
elements
())(
[
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
[
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
});
});
}
template
<
class
F
,
class
...
Arguments
>
void
nary_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
{
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
bool
packed
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
packed
();
});
bool
same_shapes
=
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
if
(
standard
or
(
packed
and
same_shapes
))
nary_standard_impl
(
f
,
result
,
args
...);
else
nary_nonstandard_impl
(
f
,
result
,
args
...);
}
template
<
class
...
Arguments
>
auto
nary_nonstandard
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
}
template
<
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
nary_standard_impl
(
f
,
result
,
args
...);
};
}
template
<
class
...
Arguments
>
auto
nary
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
nary_impl
(
f
,
result
,
args
...);
};
}
inline
auto
nary
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
return
[
=
](
auto
f
)
{
// TODO: Check result and arg1 shape is the same
if
(
arg1
.
get_shape
().
standard
()
and
arg2
.
get_shape
().
broadcasted
())
{
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
const
auto
&
strides
=
arg2
.
get_shape
().
strides
();
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
assert
(
arg2
.
get_shape
().
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
binary_broadcast_vec_impl
(
f
,
result
,
arg1
,
arg2
);
else
binary_broadcast_impl
(
f
,
result
,
arg1
,
arg2
);
return
;
}
}
nary_impl
(
f
,
result
,
arg1
,
arg2
);
};
};
}
}
...
...
src/targets/gpu/device/include/migraph/gpu/device/tensor.hpp
View file @
15bf3d62
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
#define MIGRAPH_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <migraph/functional.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
...
@@ -53,14 +54,13 @@ template <size_t NDim>
...
@@ -53,14 +54,13 @@ template <size_t NDim>
struct
hip_tensor_descriptor
struct
hip_tensor_descriptor
{
{
__device__
__host__
hip_tensor_descriptor
()
=
default
;
__device__
__host__
hip_tensor_descriptor
()
=
default
;
template
<
typename
T
,
typename
V
>
__device__
__host__
hip_tensor_descriptor
(
const
T
&
lens_ext
,
const
V
&
strides_ext
)
hip_tensor_descriptor
(
const
shape
&
s
)
{
{
for
(
size_t
i
=
0
;
i
<
NDim
;
i
++
)
std
::
copy
(
s
.
lens
().
begin
(),
s
.
lens
().
end
(),
lens
);
lens
[
i
]
=
lens_ext
[
i
];
std
::
copy
(
s
.
strides
().
begin
(),
s
.
strides
().
end
(),
strides
);
for
(
size_t
i
=
0
;
i
<
NDim
;
i
++
)
strides
[
i
]
=
strides_ext
[
i
];
}
}
__device__
__host__
hip_index
<
NDim
>
multi
(
size_t
idx
)
const
__device__
__host__
hip_index
<
NDim
>
multi
(
size_t
idx
)
const
{
{
hip_index
<
NDim
>
result
{};
hip_index
<
NDim
>
result
{};
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment