Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
15bf3d62
Commit
15bf3d62
authored
Aug 30, 2018
by
Paul
Browse files
Merge branch 'master' into memory_coloring
parents
7fa4d978
d2778c9e
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
595 additions
and
77 deletions
+595
-77
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+66
-0
src/include/migraph/functional.hpp
src/include/migraph/functional.hpp
+49
-1
src/include/migraph/fwd_conv_batchnorm_rewrite.hpp
src/include/migraph/fwd_conv_batchnorm_rewrite.hpp
+19
-0
src/include/migraph/generate.hpp
src/include/migraph/generate.hpp
+23
-1
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+5
-0
src/include/migraph/tracer.hpp
src/include/migraph/tracer.hpp
+1
-8
src/include/migraph/verify_args.hpp
src/include/migraph/verify_args.hpp
+51
-0
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+6
-2
src/onnx/cifar10.cpp
src/onnx/cifar10.cpp
+107
-0
src/onnx/mnist.cpp
src/onnx/mnist.cpp
+11
-14
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+1
-1
src/onnx/softmax.hpp
src/onnx/softmax.hpp
+14
-0
src/onnx/verify_onnx.cpp
src/onnx/verify_onnx.cpp
+2
-16
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-0
src/targets/gpu/device/add.cpp
src/targets/gpu/device/add.cpp
+15
-0
src/targets/gpu/device/add_relu.cpp
src/targets/gpu/device/add_relu.cpp
+2
-3
src/targets/gpu/device/include/migraph/gpu/device/launch.hpp
src/targets/gpu/device/include/migraph/gpu/device/launch.hpp
+10
-2
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+205
-23
src/targets/gpu/device/include/migraph/gpu/device/tensor.hpp
src/targets/gpu/device/include/migraph/gpu/device/tensor.hpp
+6
-6
No files found.
src/CMakeLists.txt
View file @
15bf3d62
...
@@ -3,6 +3,7 @@ add_library(migraph
...
@@ -3,6 +3,7 @@ add_library(migraph
auto_contiguous.cpp
auto_contiguous.cpp
dead_code_elimination.cpp
dead_code_elimination.cpp
eliminate_contiguous.cpp
eliminate_contiguous.cpp
fwd_conv_batchnorm_rewrite.cpp
env.cpp
env.cpp
generate.cpp
generate.cpp
program.cpp
program.cpp
...
...
src/fwd_conv_batchnorm_rewrite.cpp
0 → 100644
View file @
15bf3d62
#include <migraph/fwd_conv_batchnorm_rewrite.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>
namespace
migraph
{
/// Folds an inference-mode batch normalization into the convolution feeding it.
///
/// A `batch_norm_inference` instruction is rewritten when every argument after
/// the first is a `@literal` and its first argument is a `convolution` whose
/// weights (argument 1) are also a `@literal`. Using
///
///     bn(y) = gamma * (y - mean) / sqrt(variance + epsilon) + bias
///
/// the pair is replaced by a convolution with rescaled weights followed by a
/// broadcast add of a per-output-channel constant:
///
///     W'[k, c, h, w] = gamma[k] / sqrt(variance[k] + epsilon) * W[k, c, h, w]
///     b'[k]          = bias[k] - gamma[k] * mean[k] / sqrt(variance[k] + epsilon)
///
/// @param p Program to rewrite; instructions are replaced and inserted in place.
void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        if(ins->op.name() != "batch_norm_inference")
            continue;
        // The statistics and affine parameters must all be known constants.
        if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
               return arg->op.name() == "@literal";
           }))
            continue;

        auto conv_ins = ins->arguments[0];
        if(conv_ins->op.name() != "convolution")
            continue;
        // The weights are rewritten below, so they have to be a constant as well.
        if(conv_ins->arguments[1]->op.name() != "@literal")
            continue;

        // Get scale, bias, mean, variance from instruction_ref
        const auto& gamma    = ins->arguments[1]->get_literal();
        const auto& bias     = ins->arguments[2]->get_literal();
        const auto& mean     = ins->arguments[3]->get_literal();
        const auto& variance = ins->arguments[4]->get_literal();
        // Get epsilon
        auto bn_op   = any_cast<batch_norm_inference>(ins->op);
        auto epsilon = bn_op.epsilon;
        // Get convolution weights
        const auto& weights = conv_ins->arguments[1]->get_literal();
        // Get convolution op
        auto conv_op      = conv_ins->op;
        auto weights_lens = weights.get_shape().lens();

        argument new_weights{weights.get_shape()};
        argument new_bias{bias.get_shape()};
        visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
            [&](auto weights2,
                auto gamma2,
                auto bias2,
                auto mean2,
                auto variance2,
                auto new_weights2,
                auto new_bias2) {
                // Scale every filter k by gamma[k] / sqrt(variance[k] + epsilon).
                dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
                    [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
                        new_weights2(k, c, h, w) =
                            gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
                    });
                // The mean term carries the same gamma scale as the weights;
                // leaving gamma out here shifts the output whenever gamma != 1.
                dfor(new_bias.get_shape().elements())([&](std::size_t c) {
                    new_bias2(c) =
                        bias2(c) - (gamma2(c) * mean2(c) / std::sqrt(variance2(c) + epsilon));
                });
            });

        // Replace convolution instruction with updated weights
        // NOTE(review): conv_ins is rewritten in place; if another instruction
        // also consumes it, that consumer would see the rescaled weights - confirm
        // whether that case needs to be excluded.
        auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
        auto l_bias    = p.add_literal({new_bias.get_shape(), new_bias.data()});
        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
        auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias);
        p.replace_instruction(ins, add{}, {c, b});
    }
}
}
// namespace migraph
src/include/migraph/functional.hpp
View file @
15bf3d62
...
@@ -5,6 +5,14 @@
...
@@ -5,6 +5,14 @@
namespace
migraph
{
namespace
migraph
{
struct
swallow
{
template
<
class
...
Ts
>
constexpr
swallow
(
Ts
&&
...)
{
}
};
namespace
detail
{
namespace
detail
{
template
<
class
R
,
class
F
>
template
<
class
R
,
class
F
>
...
@@ -19,8 +27,48 @@ struct fix_f
...
@@ -19,8 +27,48 @@ struct fix_f
}
}
};
};
template
<
std
::
size_t
...>
struct
seq
{
using
type
=
seq
;
};
template
<
class
,
class
>
struct
merge_seq
;
template
<
std
::
size_t
...
Xs
,
std
::
size_t
...
Ys
>
struct
merge_seq
<
seq
<
Xs
...
>
,
seq
<
Ys
...
>>
:
seq
<
Xs
...,
(
sizeof
...(
Xs
)
+
Ys
)...
>
{
};
template
<
std
::
size_t
N
>
struct
gens
:
merge_seq
<
typename
gens
<
N
/
2
>::
type
,
typename
gens
<
N
-
N
/
2
>::
type
>
{
};
template
<
>
struct
gens
<
0
>
:
seq
<>
{
};
template
<
>
struct
gens
<
1
>
:
seq
<
0
>
{
};
template
<
class
F
,
std
::
size_t
...
Ns
>
constexpr
void
repeat_c_impl
(
F
f
,
seq
<
Ns
...
>
)
{
swallow
{(
f
(
std
::
integral_constant
<
std
::
size_t
,
Ns
>
{}),
0
)...};
}
}
// namespace detail
}
// namespace detail
template
<
std
::
size_t
N
,
class
F
>
constexpr
void
repeat_c
(
F
f
)
{
detail
::
repeat_c_impl
(
f
,
detail
::
gens
<
N
>
{});
}
/// Implements a fix-point combinator
/// Implements a fix-point combinator
template
<
class
R
,
class
F
>
template
<
class
R
,
class
F
>
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
...
@@ -35,7 +83,7 @@ auto fix(F f)
...
@@ -35,7 +83,7 @@ auto fix(F f)
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
make_sequence
(
Ts
...
xs
)
auto
pack
(
Ts
...
xs
)
{
{
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
}
}
...
...
src/include/migraph/fwd_conv_batchnorm_rewrite.hpp
0 → 100644
View file @
15bf3d62
#ifndef MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace
migraph
{
struct program;

/// Pass object exposing a name and an `apply` entry point that takes the
/// program to work on by mutable reference.
struct fwd_conv_batchnorm_rewrite
{
    /// Returns the identifier of this pass.
    std::string name() const
    {
        return std::string{"fwd_conv_batchnorm_rewrite"};
    }

    /// Runs the pass on `p`; only declared here, the definition lives elsewhere.
    void apply(program& p) const;
};
}
// namespace migraph
#endif
src/include/migraph/generate.hpp
View file @
15bf3d62
...
@@ -12,7 +12,11 @@ constexpr T normalize(unsigned long z)
...
@@ -12,7 +12,11 @@ constexpr T normalize(unsigned long z)
{
{
if
(
z
==
0
)
if
(
z
==
0
)
return
0
;
return
0
;
return
(
2.0
/
z
)
-
1.0
;
const
auto
max
=
32768
;
const
double
range
=
max
/
2
;
// NOLINT
double
result
=
(
z
%
max
)
/
range
;
result
-=
1
;
return
result
;
}
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_signed
<
T
>{}
and
not
std
::
is_floating_point
<
T
>
{})
>
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_signed
<
T
>{}
and
not
std
::
is_floating_point
<
T
>
{})
>
...
@@ -54,11 +58,29 @@ struct xorshf96_generator
...
@@ -54,11 +58,29 @@ struct xorshf96_generator
}
}
};
};
template
<
class
T
>
struct
xorshift_generator
{
unsigned
long
x
;
xorshift_generator
(
unsigned
long
seed
=
0
)
:
x
(
521288629ULL
^
seed
)
{}
constexpr
T
operator
()()
noexcept
{
x
^=
x
>>
12U
;
x
^=
x
<<
25U
;
x
^=
x
>>
27U
;
return
normalize
<
T
>
(
x
*
0x2545F4914F6CDD1D
);
}
};
template
<
class
T
>
template
<
class
T
>
std
::
vector
<
T
>
generate_tensor_data
(
const
migraph
::
shape
&
s
,
unsigned
long
seed
=
0
)
std
::
vector
<
T
>
generate_tensor_data
(
const
migraph
::
shape
&
s
,
unsigned
long
seed
=
0
)
{
{
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
generate
(
result
.
begin
(),
result
.
end
(),
xorshf96_generator
<
T
>
{
seed
});
std
::
generate
(
result
.
begin
(),
result
.
end
(),
xorshf96_generator
<
T
>
{
seed
});
// std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
// std::generate(result.begin(), result.end(), []{ return 1; });
return
result
;
return
result
;
}
}
...
...
src/include/migraph/instruction.hpp
View file @
15bf3d62
...
@@ -115,6 +115,11 @@ struct instruction
...
@@ -115,6 +115,11 @@ struct instruction
}
}
shape
get_shape
()
const
{
return
result
;
}
shape
get_shape
()
const
{
return
result
;
}
const
literal
&
get_literal
()
const
{
assert
(
op
.
name
()
==
"@literal"
);
return
lit
;
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
...
...
src/include/migraph/tracer.hpp
View file @
15bf3d62
...
@@ -2,17 +2,10 @@
...
@@ -2,17 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#include <ostream>
#include <ostream>
#include <migraph/functional.hpp>
namespace
migraph
{
namespace
migraph
{
struct
swallow
{
template
<
class
...
Ts
>
swallow
(
Ts
&&
...)
{
}
};
struct
tracer
struct
tracer
{
{
tracer
()
{}
tracer
()
{}
...
...
src/include/migraph/verify_args.hpp
0 → 100644
View file @
15bf3d62
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#include <migraph/verify.hpp>
#include <migraph/argument.hpp>
namespace
migraph
{
/// Compares two arguments and prints a diagnostic report to stdout when they differ.
///
/// Nothing is printed when `verify_range` accepts the pair. Otherwise the report
/// names the failing case, dumps each range if it has fewer than 32 elements,
/// flags all-zero ranges, shows the first mismatching element and the first
/// non-finite element found in either range.
///
/// @param name      Label printed after "FAILED: ".
/// @param cpu_arg   First argument, reported as "cpu".
/// @param gpu_arg   Second argument, reported as "gpu".
/// @param tolerance Forwarded unchanged to `verify_range`.
inline void verify_args(const std::string& name,
                        const argument& cpu_arg,
                        const argument& gpu_arg,
                        double tolerance = 80)
{
    visit_all(cpu_arg, gpu_arg)([&](auto cpu, auto gpu) {
        // Matching ranges produce no output at all.
        if(verify_range(cpu, gpu, tolerance))
            return;

        std::cout << "FAILED: " << name << std::endl;

        // Only small ranges are dumped in full.
        if(cpu.size() < 32)
            std::cout << "cpu:" << cpu << std::endl;
        if(gpu.size() < 32)
            std::cout << "gpu:" << gpu << std::endl;

        if(range_zero(cpu))
            std::cout << "Cpu data is all zeros" << std::endl;
        if(range_zero(gpu))
            std::cout << "Gpu data is all zeros" << std::endl;

        // First position where the two ranges are not float_equal.
        auto first_diff = mismatch_idx(cpu, gpu, float_equal);
        if(first_diff < range_distance(cpu))
        {
            std::cout << "Mismatch at " << first_diff << ": " << cpu[first_diff]
                      << " != " << gpu[first_diff] << std::endl;
        }

        // A non-negative index from find_idx is treated as "found".
        auto cpu_bad = find_idx(cpu, not_finite);
        if(cpu_bad >= 0)
            std::cout << "Non finite number found in cpu at " << cpu_bad << ": " << cpu[cpu_bad]
                      << std::endl;

        auto gpu_bad = find_idx(gpu, not_finite);
        if(gpu_bad >= 0)
            std::cout << "Non finite number found in gpu at " << gpu_bad << ": " << gpu[gpu_bad]
                      << std::endl;

        std::cout << std::endl;
    });
}
}
// namespace migraph
#endif
src/onnx/CMakeLists.txt
View file @
15bf3d62
...
@@ -17,11 +17,15 @@ rocm_clang_tidy_check(read_onnx)
...
@@ -17,11 +17,15 @@ rocm_clang_tidy_check(read_onnx)
target_link_libraries
(
read_onnx migraph_onnx
)
target_link_libraries
(
read_onnx migraph_onnx
)
if
(
MIGRAPH_ENABLE_GPU
)
add_executable
(
mnist mnist.cpp
)
add_executable
(
mnist mnist.cpp
)
rocm_clang_tidy_check
(
mnist
)
rocm_clang_tidy_check
(
mnist
)
target_link_libraries
(
mnist migraph_cpu migraph_onnx
)
target_link_libraries
(
mnist migraph_cpu migraph_gpu migraph_onnx
)
add_executable
(
cifar10 cifar10.cpp
)
rocm_clang_tidy_check
(
cifar10
)
target_link_libraries
(
cifar10 migraph_cpu migraph_gpu migraph_onnx
)
if
(
MIGRAPH_ENABLE_GPU
)
add_executable
(
verify_onnx verify_onnx.cpp
)
add_executable
(
verify_onnx verify_onnx.cpp
)
rocm_clang_tidy_check
(
verify_onnx
)
rocm_clang_tidy_check
(
verify_onnx
)
target_link_libraries
(
verify_onnx migraph_onnx migraph_cpu migraph_gpu
)
target_link_libraries
(
verify_onnx migraph_onnx migraph_cpu migraph_gpu
)
...
...
src/onnx/cifar10.cpp
0 → 100644
View file @
15bf3d62
#include <cstdio>
#include <string>
#include <fstream>
#include <numeric>
#include <stdexcept>
#include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include "softmax.hpp"
/// Reads the first 10 records of a CIFAR-10 binary batch file.
///
/// Each record is one label byte followed by 3072 pixel bytes. Pixel bytes are
/// converted to floats by dividing by 255.
///
/// @param full_path Path of the binary batch file.
/// @return Pair of (10 labels, 10 * 3072 pixel values stored image after image).
/// @throws std::runtime_error if the file cannot be opened or holds fewer than
///         10 complete records.
auto read_cifar10_images(const std::string& full_path)
{
    const size_t nimages          = 10;
    const size_t nbytes_per_image = 3072;
    const size_t record_size      = nbytes_per_image + 1; // label byte + pixels

    std::ifstream file(full_path, std::ios::binary);
    if(not file.is_open())
    {
        throw std::runtime_error("Cannot open file `" + full_path + "`!");
    }

    std::vector<uint8_t> raw_data(nimages * record_size);
    const auto expected = static_cast<std::streamsize>(raw_data.size());
    file.read(reinterpret_cast<char*>(raw_data.data()), expected);
    // A short read would otherwise leave zero-filled records that look like valid data.
    if(file.gcount() != expected)
    {
        throw std::runtime_error("Cannot read " + std::to_string(nimages) +
                                 " images from file `" + full_path + "`!");
    }

    std::vector<uint8_t> labels(nimages);
    std::vector<float> data(nimages * nbytes_per_image);
    for(size_t i = 0; i < nimages; i++)
    {
        const uint8_t* record = raw_data.data() + i * record_size;
        labels[i]             = record[0];
        const uint8_t* pixels = record + 1;
        for(size_t j = 0; j < nbytes_per_image; j++)
        {
            data[i * nbytes_per_image + j] = pixels[j] / 255.0f;
        }
    }
    return std::make_pair(labels, data);
}
int
main
(
int
argc
,
char
const
*
argv
[])
{
if
(
argc
<
4
)
{
throw
std
::
runtime_error
(
"Usage: cifar10 [gpu | cpu] <onnx file> <cifar10 data file>"
);
}
std
::
string
gpu_cpu
=
argv
[
1
];
std
::
string
file
=
argv
[
2
];
std
::
string
datafile
=
argv
[
3
];
auto
prog
=
migraph
::
parse_onnx
(
file
);
std
::
cout
<<
prog
<<
std
::
endl
;
auto
imageset
=
read_cifar10_images
(
datafile
);
if
(
gpu_cpu
==
"gpu"
)
{
// GPU target
prog
.
compile
(
migraph
::
gpu
::
target
{});
migraph
::
program
::
parameter_map
m
;
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}};
for
(
auto
&&
x
:
prog
.
get_parameter_shapes
())
{
m
[
x
.
first
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
generate_argument
(
x
.
second
));
}
auto
labels
=
imageset
.
first
;
auto
input
=
imageset
.
second
;
auto
ptr
=
input
.
data
();
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
std
::
cout
<<
"label: "
<<
static_cast
<
uint32_t
>
(
labels
[
i
])
<<
" ----> "
;
m
[
"0"
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
argument
{
s
,
&
ptr
[
3072
*
i
]});
auto
result
=
migraph
::
gpu
::
from_gpu
(
prog
.
eval
(
m
));
std
::
vector
<
float
>
logits
;
result
.
visit
([
&
](
auto
output
)
{
logits
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
probs
=
softmax
<
float
>
(
logits
);
for
(
auto
x
:
probs
)
std
::
cout
<<
x
<<
" "
;
std
::
cout
<<
std
::
endl
<<
std
::
endl
;
}
}
else
{
// CPU target
prog
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}};
auto
labels
=
imageset
.
first
;
auto
input
=
imageset
.
second
;
auto
ptr
=
input
.
data
();
for
(
int
i
=
0
;
i
<
10
;
i
++
)
{
std
::
cout
<<
"label: "
<<
static_cast
<
uint32_t
>
(
labels
[
i
])
<<
" ----> "
;
auto
input3
=
migraph
::
argument
{
s
,
&
ptr
[
3072
*
i
]};
auto
result
=
prog
.
eval
({{
"0"
,
input3
}});
std
::
vector
<
float
>
logits
;
result
.
visit
([
&
](
auto
output
)
{
logits
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
probs
=
softmax
<
float
>
(
logits
);
for
(
auto
x
:
probs
)
std
::
cout
<<
x
<<
" "
;
std
::
cout
<<
std
::
endl
;
}
}
}
src/onnx/mnist.cpp
View file @
15bf3d62
...
@@ -6,9 +6,12 @@
...
@@ -6,9 +6,12 @@
#include <migraph/onnx.hpp>
#include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include <migraph/generate.hpp>
#include "softmax.hpp"
auto
reverse_int
(
unsigned
int
i
)
auto
reverse_int
(
unsigned
int
i
)
{
{
unsigned
char
c1
,
c2
,
c3
,
c4
;
unsigned
char
c1
,
c2
,
c3
,
c4
;
...
@@ -97,16 +100,6 @@ std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number
...
@@ -97,16 +100,6 @@ std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number
}
}
}
}
std
::
vector
<
float
>
softmax
(
std
::
vector
<
float
>
p
)
{
size_t
n
=
p
.
size
();
std
::
vector
<
float
>
result
(
n
);
std
::
transform
(
p
.
begin
(),
p
.
end
(),
result
.
begin
(),
[](
auto
x
)
{
return
std
::
exp
(
x
);
});
float
s
=
std
::
accumulate
(
result
.
begin
(),
result
.
end
(),
0.0
f
,
std
::
plus
<
float
>
());
std
::
transform
(
result
.
begin
(),
result
.
end
(),
result
.
begin
(),
[
=
](
auto
x
)
{
return
x
/
s
;
});
return
result
;
}
int
main
(
int
argc
,
char
const
*
argv
[])
int
main
(
int
argc
,
char
const
*
argv
[])
{
{
if
(
argc
>
3
)
if
(
argc
>
3
)
...
@@ -121,15 +114,19 @@ int main(int argc, char const* argv[])
...
@@ -121,15 +114,19 @@ int main(int argc, char const* argv[])
std
::
string
file
=
argv
[
1
];
std
::
string
file
=
argv
[
1
];
auto
prog
=
migraph
::
parse_onnx
(
file
);
auto
prog
=
migraph
::
parse_onnx
(
file
);
prog
.
compile
(
migraph
::
cpu
::
cpu_target
{});
std
::
cout
<<
prog
<<
std
::
endl
<<
std
::
endl
;
prog
.
compile
(
migraph
::
gpu
::
target
{});
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
1
,
28
,
28
}};
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
1
,
28
,
28
}};
std
::
cout
<<
s
<<
std
::
endl
;
std
::
cout
<<
s
<<
std
::
endl
;
auto
ptr
=
input
.
data
();
auto
ptr
=
input
.
data
();
migraph
::
program
::
parameter_map
m
;
m
[
"output"
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
generate_argument
(
prog
.
get_parameter_shape
(
"output"
)));
for
(
int
i
=
0
;
i
<
20
;
i
++
)
for
(
int
i
=
0
;
i
<
20
;
i
++
)
{
{
std
::
cout
<<
"label: "
<<
labels
[
i
]
<<
" ----> "
;
std
::
cout
<<
"label: "
<<
labels
[
i
]
<<
" ----> "
;
auto
input3
=
migraph
::
argument
{
s
,
&
ptr
[
784
*
i
]};
m
[
"0"
]
=
migraph
::
gpu
::
to_gpu
(
migraph
::
argument
{
s
,
&
ptr
[
784
*
i
]}
)
;
auto
result
=
prog
.
eval
({{
"Input3"
,
input3
}}
);
auto
result
=
migraph
::
gpu
::
from_gpu
(
prog
.
eval
(
m
)
);
std
::
vector
<
float
>
logits
;
std
::
vector
<
float
>
logits
;
result
.
visit
([
&
](
auto
output
)
{
logits
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
logits
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
probs
=
softmax
(
logits
);
std
::
vector
<
float
>
probs
=
softmax
(
logits
);
...
...
src/onnx/onnx.cpp
View file @
15bf3d62
...
@@ -234,7 +234,7 @@ struct onnx_parser
...
@@ -234,7 +234,7 @@ struct onnx_parser
}
}
if
(
contains
(
attributes
,
"momentum"
))
if
(
contains
(
attributes
,
"momentum"
))
{
{
epsilon
=
parse_value
(
attributes
.
at
(
"momentum"
)).
at
<
float
>
();
momentum
=
parse_value
(
attributes
.
at
(
"momentum"
)).
at
<
float
>
();
}
}
if
(
contains
(
attributes
,
"is_test"
))
if
(
contains
(
attributes
,
"is_test"
))
{
{
...
...
src/onnx/softmax.hpp
0 → 100644
View file @
15bf3d62
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
/// Computes the softmax of `p`: result[i] = exp(p[i]) / sum_j exp(p[j]).
///
/// The largest element is subtracted from every input before exponentiating.
/// This leaves the mathematical result unchanged but keeps `exp` from
/// overflowing to infinity (and the quotient from becoming NaN) for large inputs.
///
/// @tparam T Element type of the scores; the sum is accumulated in `T`.
/// @param p  Input scores; may be empty.
/// @return   Vector of the same length as `p`; empty when `p` is empty.
template <typename T>
std::vector<T> softmax(const std::vector<T>& p)
{
    std::vector<T> result(p.size());
    if(p.empty())
        return result;

    const T largest = *std::max_element(p.begin(), p.end());
    std::transform(
        p.begin(), p.end(), result.begin(), [=](T x) { return std::exp(x - largest); });

    // Seed with T{0}: std::accumulate sums in the type of its initial value, so a
    // float seed would silently reduce the precision of a double sum.
    const T total = std::accumulate(result.begin(), result.end(), T{0});
    std::transform(
        result.begin(), result.end(), result.begin(), [=](T x) { return x / total; });
    return result;
}
src/onnx/verify_onnx.cpp
View file @
15bf3d62
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include <migraph/generate.hpp>
#include <migraph/verify.hpp>
#include <migraph/verify
_args
.hpp>
migraph
::
argument
run_cpu
(
const
std
::
string
&
file
)
migraph
::
argument
run_cpu
(
const
std
::
string
&
file
)
{
{
...
@@ -46,20 +46,6 @@ int main(int argc, char const* argv[])
...
@@ -46,20 +46,6 @@ int main(int argc, char const* argv[])
auto
x
=
run_cpu
(
file
);
auto
x
=
run_cpu
(
file
);
auto
y
=
run_gpu
(
file
);
auto
y
=
run_gpu
(
file
);
visit_all
(
x
,
y
)([](
auto
cpu
,
auto
gpu
)
{
migraph
::
verify_args
(
file
,
x
,
y
,
100
);
if
(
migraph
::
verify_range
(
cpu
,
gpu
,
100
))
{
std
::
cout
<<
"Passed"
<<
std
::
endl
;
}
else
{
std
::
cout
<<
"Not equal"
<<
std
::
endl
;
std
::
cout
<<
"cpu:"
<<
std
::
endl
;
std
::
cout
<<
cpu
<<
std
::
endl
;
std
::
cout
<<
"gpu:"
<<
std
::
endl
;
std
::
cout
<<
gpu
<<
std
::
endl
;
}
});
}
}
}
}
src/targets/gpu/CMakeLists.txt
View file @
15bf3d62
...
@@ -11,6 +11,7 @@ if(NOT TARGET MIOpen)
...
@@ -11,6 +11,7 @@ if(NOT TARGET MIOpen)
endif
()
endif
()
add_library
(
migraph_device
add_library
(
migraph_device
device/add.cpp
device/add_relu.cpp
device/add_relu.cpp
device/contiguous.cpp
device/contiguous.cpp
)
)
...
...
src/targets/gpu/device/add.cpp
0 → 100644
View file @
15bf3d62
#include <migraph/gpu/device/add.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace
migraph
{
namespace
gpu
{
namespace
device
{
/// Dispatches a binary addition through `nary`.
///
/// `nary` is called with `result`, `arg1` and `arg2`, and the callable it
/// returns is invoked with a functor that returns the sum of its two operands.
void add(const argument& result, const argument& arg1, const argument& arg2)
{
    auto run_binary = nary(result, arg1, arg2);
    run_binary([](auto lhs, auto rhs) { return lhs + rhs; });
}
}
// namespace device
}
// namespace gpu
}
// namespace migraph
src/targets/gpu/device/add_relu.cpp
View file @
15bf3d62
...
@@ -5,10 +5,9 @@ namespace migraph {
...
@@ -5,10 +5,9 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
add_relu
(
argument
result
,
argument
arg1
,
argument
arg2
)
void
add_relu
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
{
nary_standard
(
std
::
move
(
result
),
std
::
move
(
arg1
),
std
::
move
(
arg2
))(
nary
(
result
,
arg1
,
arg2
)([](
auto
x
,
auto
y
)
{
return
std
::
max
<
decltype
(
x
+
y
)
>
(
0
,
x
+
y
);
});
[](
auto
x
,
auto
y
)
{
return
max
(
0
,
x
+
y
);
});
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/include/migraph/gpu/device/launch.hpp
View file @
15bf3d62
...
@@ -33,10 +33,10 @@ inline auto launch(std::size_t global, std::size_t local)
...
@@ -33,10 +33,10 @@ inline auto launch(std::size_t global, std::size_t local)
};
};
}
}
inline
auto
gs_launch
(
std
::
size_t
n
,
std
::
size_t
local
=
512
)
inline
auto
gs_launch
(
std
::
size_t
n
,
std
::
size_t
local
=
1024
)
{
{
std
::
size_t
groups
=
1
+
n
/
local
;
std
::
size_t
groups
=
1
+
n
/
local
;
std
::
size_t
nglobal
=
std
::
min
<
std
::
size_t
>
(
512
,
groups
)
*
local
;
std
::
size_t
nglobal
=
std
::
min
<
std
::
size_t
>
(
256
,
groups
)
*
local
;
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
launch
(
nglobal
,
local
)([
=
](
auto
idx
)
{
launch
(
nglobal
,
local
)([
=
](
auto
idx
)
{
...
@@ -48,6 +48,14 @@ inline auto gs_launch(std::size_t n, std::size_t local = 512)
...
@@ -48,6 +48,14 @@ inline auto gs_launch(std::size_t n, std::size_t local = 512)
};
};
}
}
// Workaround hcc's broken tile_static macro
#ifdef tile_static
#undef tile_static
#define MIGRAPH_DEVICE_SHARED __attribute__((tile_static))
#else
#define MIGRAPH_DEVICE_SHARED __shared__
#endif
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
}
// namespace migraph
}
// namespace migraph
...
...
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
View file @
15bf3d62
...
@@ -10,16 +10,25 @@ namespace migraph {
...
@@ -10,16 +10,25 @@ namespace migraph {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
template
<
class
...
Arguments
>
template
<
class
T
>
auto
nary
(
argument
result
,
Arguments
...
args
)
using
vec4
=
T
__attribute__
((
ext_vector_type
(
4
)));
template
<
class
T
>
__device__
__host__
vec4
<
T
>*
as_vec4
(
T
*
x
)
{
{
return
[
=
](
auto
f
)
{
return
reinterpret_cast
<
vec4
<
T
>*>
(
x
);
if
(
all_of
({
args
...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
}))
}
nary_standard
(
result
,
args
...)(
f
);
else
nary_nonstandard
(
result
,
args
...)(
f
);
};
template
<
class
T
>
__device__
__host__
T
*
as_pointer
(
vec4
<
T
>*
x
)
{
return
reinterpret_cast
<
T
*>
(
x
);
}
template
<
class
...
Ts
>
auto
pack_vec4
(
Ts
...
xs
)
{
return
[
=
](
auto
f
,
std
::
size_t
n
)
{
return
f
(
as_vec4
(
xs
)[
n
]...);
};
}
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
...
@@ -28,14 +37,12 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
...
@@ -28,14 +37,12 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
output_shape
.
lens
().
size
(),
[
&
](
auto
ndim
)
{
auto
data
=
make_sequence
(
auto
data
=
pack
(
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
().
lens
(),
std
::
make_pair
(
hip_tensor_descriptor
<
ndim
>
{
inputs
.
get_shape
()},
inputs
.
data
())...);
inputs
.
get_shape
().
strides
()},
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
);
inputs
.
data
())...);
hip_tensor_descriptor
<
ndim
>
out_desc
(
output_shape
.
lens
(),
output_shape
.
strides
());
auto
*
outp
=
output
.
data
();
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
gs_launch
(
output_shape
.
elements
())([
=
](
auto
i
)
{
data
([
&
](
auto
...
ps
)
{
data
([
&
](
auto
&&
...
ps
)
{
auto
outidx
=
out_desc
.
multi
(
i
);
auto
outidx
=
out_desc
.
multi
(
i
);
outp
[
i
]
=
f
(
ps
.
second
[
ps
.
first
.
linear
(
outidx
)]...);
outp
[
i
]
=
f
(
ps
.
second
[
ps
.
first
.
linear
(
outidx
)]...);
});
});
...
@@ -44,24 +51,199 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
...
@@ -44,24 +51,199 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
});
});
}
}
template
<
class
F
>
void
binary_broadcast_vec_impl
(
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
auto
*
xp
=
as_vec4
(
input1
.
data
());
auto
*
yp
=
as_vec4
(
input2
.
data
());
auto
*
outp
=
as_vec4
(
output
.
data
());
const
std
::
size_t
vec_size
=
4
;
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
()
/
vec_size
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
vec4
<
type
>
buffer
[
2048
/
vec_size
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
yp
[
i
];
}
__syncthreads
();
auto
*
bp
=
as_pointer
(
buffer
);
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
auto
bidx
=
((
i
*
vec_size
)
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
bp
[
bidx
];
vec4
<
type
>
x
=
xp
[
i
];
vec4
<
type
>
out
=
outp
[
i
];
for
(
std
::
size_t
j
=
0
;
j
<
vec_size
;
j
++
)
{
out
[
j
]
=
f
(
x
[
j
],
b
);
}
outp
[
i
]
=
out
;
}
});
});
}
template
<
class
F
>
void
binary_broadcast_impl
(
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
auto
*
xp
=
input1
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
outp
=
output
.
data
();
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
();
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
type
buffer
[
2048
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
yp
[
i
];
}
__syncthreads
();
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
buffer
[
bidx
];
type
x
=
xp
[
i
];
outp
[
i
]
=
f
(
x
,
b
);
}
});
});
}
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_vec_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
const
std
::
size_t
vec_size
=
4
;
auto
data
=
pack_vec4
(
inputs
.
data
()...);
auto
*
outp
=
as_vec4
(
output
.
data
());
gs_launch
(
output_shape
.
elements
()
/
vec_size
)([
=
](
auto
i
)
{
vec4
<
type
>
out
=
outp
[
i
];
data
(
[
&
](
auto
...
xs
)
{
for
(
std
::
size_t
j
=
0
;
j
<
vec_size
;
j
++
)
{
out
[
j
]
=
f
(
xs
[
j
]...);
}
},
i
);
outp
[
i
]
=
out
;
});
});
}
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
auto
data
=
pack
(
inputs
.
data
()...);
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())(
[
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
});
}
template
<
class
F
,
class
...
Arguments
>
void
nary_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
{
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
bool
packed
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
packed
();
});
bool
same_shapes
=
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
if
(
standard
or
(
packed
and
same_shapes
))
nary_standard_impl
(
f
,
result
,
args
...);
else
nary_nonstandard_impl
(
f
,
result
,
args
...);
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary_nonstandard
(
argument
result
,
Arguments
...
args
)
auto
nary_nonstandard
(
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
return
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
return
[
=
](
auto
f
)
{
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
nary_standard_impl
(
f
,
result
,
args
...);
};
}
template
<
class
...
Arguments
>
auto
nary
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
nary_impl
(
f
,
result
,
args
...);
};
}
inline
auto
nary
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
// TODO: Check result and arg1 shape is the same
const
auto
&
output_shape
=
result
.
get_shape
();
if
(
arg1
.
get_shape
().
standard
()
and
arg2
.
get_shape
().
broadcasted
())
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
{
auto
data
=
make_sequence
(
inputs
.
data
()...);
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
auto
*
outp
=
output
.
data
();
const
auto
&
strides
=
arg2
.
get_shape
().
strides
();
gs_launch
(
output_shape
.
elements
())(
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
[
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
});
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
assert
(
arg2
.
get_shape
().
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
binary_broadcast_vec_impl
(
f
,
result
,
arg1
,
arg2
);
else
binary_broadcast_impl
(
f
,
result
,
arg1
,
arg2
);
return
;
}
}
nary_impl
(
f
,
result
,
arg1
,
arg2
);
};
};
}
}
...
...
src/targets/gpu/device/include/migraph/gpu/device/tensor.hpp
View file @
15bf3d62
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
#define MIGRAPH_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <migraph/functional.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
...
@@ -53,14 +54,13 @@ template <size_t NDim>
...
@@ -53,14 +54,13 @@ template <size_t NDim>
struct
hip_tensor_descriptor
struct
hip_tensor_descriptor
{
{
__device__
__host__
hip_tensor_descriptor
()
=
default
;
__device__
__host__
hip_tensor_descriptor
()
=
default
;
template
<
typename
T
,
typename
V
>
__device__
__host__
hip_tensor_descriptor
(
const
T
&
lens_ext
,
const
V
&
strides_ext
)
hip_tensor_descriptor
(
const
shape
&
s
)
{
{
for
(
size_t
i
=
0
;
i
<
NDim
;
i
++
)
std
::
copy
(
s
.
lens
().
begin
(),
s
.
lens
().
end
(),
lens
);
lens
[
i
]
=
lens_ext
[
i
];
std
::
copy
(
s
.
strides
().
begin
(),
s
.
strides
().
end
(),
strides
);
for
(
size_t
i
=
0
;
i
<
NDim
;
i
++
)
strides
[
i
]
=
strides_ext
[
i
];
}
}
__device__
__host__
hip_index
<
NDim
>
multi
(
size_t
idx
)
const
__device__
__host__
hip_index
<
NDim
>
multi
(
size_t
idx
)
const
{
{
hip_index
<
NDim
>
result
{};
hip_index
<
NDim
>
result
{};
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment