Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e8afe91a
Commit
e8afe91a
authored
May 31, 2019
by
Shucai Xiao
Browse files
clang format
parent
1e731018
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
43 deletions
+35
-43
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+1
-1
src/targets/gpu/device/logsoftmax.cpp
src/targets/gpu/device/logsoftmax.cpp
+10
-10
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+24
-32
No files found.
src/targets/cpu/lowering.cpp
View file @
e8afe91a
...
@@ -578,7 +578,7 @@ struct cpu_logsoftmax
...
@@ -578,7 +578,7 @@ struct cpu_logsoftmax
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
auto
batch_lens
=
output_shape
.
lens
();
auto
batch_lens
=
output_shape
.
lens
();
batch_lens
[
op
.
axis
]
=
1
;
batch_lens
[
op
.
axis
]
=
1
;
shape
batch_shape
{
shape
::
int32_type
,
batch_lens
};
shape
batch_shape
{
shape
::
int32_type
,
batch_lens
};
...
...
src/targets/gpu/device/logsoftmax.cpp
View file @
e8afe91a
...
@@ -17,10 +17,10 @@ argument logsoftmax(hipStream_t stream,
...
@@ -17,10 +17,10 @@ argument logsoftmax(hipStream_t stream,
int
axis
)
int
axis
)
{
{
auto
lens
=
output_shape
.
lens
();
auto
lens
=
output_shape
.
lens
();
auto
num_in_batch
=
lens
[
axis
];
auto
num_in_batch
=
lens
[
axis
];
auto
batch_lens
=
lens
;
auto
batch_lens
=
lens
;
batch_lens
[
axis
]
=
1
;
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
output_shape
.
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
output_shape
.
type
(),
batch_lens
};
visit_all
(
args
.
back
(),
args
.
front
())([
&
](
auto
output
,
auto
input
)
{
visit_all
(
args
.
back
(),
args
.
front
())([
&
](
auto
output
,
auto
input
)
{
...
@@ -33,21 +33,21 @@ argument logsoftmax(hipStream_t stream,
...
@@ -33,21 +33,21 @@ argument logsoftmax(hipStream_t stream,
// each thread is for one item in the batch
// each thread is for one item in the batch
gs_launch
(
stream
,
batch_shape
.
elements
())([
=
](
auto
i
)
{
gs_launch
(
stream
,
batch_shape
.
elements
())([
=
](
auto
i
)
{
auto
batch_idx
=
desc_batch
.
multi
(
i
);
auto
batch_idx
=
desc_batch
.
multi
(
i
);
auto
data_idx
=
batch_idx
;
auto
data_idx
=
batch_idx
;
// get max
// get max
auto
batch_max
=
input_ptr
[
desc_data
.
linear
(
batch_idx
)];
auto
batch_max
=
input_ptr
[
desc_data
.
linear
(
batch_idx
)];
for
(
std
::
size_t
j
=
1
;
j
<
num_in_batch
;
++
j
)
for
(
std
::
size_t
j
=
1
;
j
<
num_in_batch
;
++
j
)
{
{
data_idx
[
axis
]
=
j
;
data_idx
[
axis
]
=
j
;
size_t
idx
=
desc_data
.
linear
(
data_idx
);
size_t
idx
=
desc_data
.
linear
(
data_idx
);
batch_max
=
std
::
max
(
to_hip_type
(
batch_max
),
to_hip_type
(
input_ptr
[
idx
]));
batch_max
=
std
::
max
(
to_hip_type
(
batch_max
),
to_hip_type
(
input_ptr
[
idx
]));
}
}
for
(
std
::
size_t
j
=
0
;
j
<
num_in_batch
;
++
j
)
for
(
std
::
size_t
j
=
0
;
j
<
num_in_batch
;
++
j
)
{
{
data_idx
[
axis
]
=
j
;
data_idx
[
axis
]
=
j
;
size_t
idx
=
desc_data
.
linear
(
data_idx
);
size_t
idx
=
desc_data
.
linear
(
data_idx
);
output_ptr
[
idx
]
=
input_ptr
[
idx
]
-
batch_max
;
output_ptr
[
idx
]
=
input_ptr
[
idx
]
-
batch_max
;
}
}
...
@@ -55,7 +55,7 @@ argument logsoftmax(hipStream_t stream,
...
@@ -55,7 +55,7 @@ argument logsoftmax(hipStream_t stream,
for
(
std
::
size_t
j
=
1
;
j
<
num_in_batch
;
++
j
)
for
(
std
::
size_t
j
=
1
;
j
<
num_in_batch
;
++
j
)
{
{
data_idx
[
axis
]
=
j
;
data_idx
[
axis
]
=
j
;
size_t
idx
=
desc_data
.
linear
(
data_idx
);
size_t
idx
=
desc_data
.
linear
(
data_idx
);
batch_sum
+=
::
exp
(
to_hip_type
(
output_ptr
[
idx
]));
batch_sum
+=
::
exp
(
to_hip_type
(
output_ptr
[
idx
]));
}
}
batch_sum
=
::
log
(
to_hip_type
(
batch_sum
));
batch_sum
=
::
log
(
to_hip_type
(
batch_sum
));
...
@@ -63,7 +63,7 @@ argument logsoftmax(hipStream_t stream,
...
@@ -63,7 +63,7 @@ argument logsoftmax(hipStream_t stream,
for
(
std
::
size_t
j
=
0
;
j
<
num_in_batch
;
++
j
)
for
(
std
::
size_t
j
=
0
;
j
<
num_in_batch
;
++
j
)
{
{
data_idx
[
axis
]
=
j
;
data_idx
[
axis
]
=
j
;
size_t
idx
=
desc_data
.
linear
(
data_idx
);
size_t
idx
=
desc_data
.
linear
(
data_idx
);
output_ptr
[
idx
]
-=
batch_sum
;
output_ptr
[
idx
]
-=
batch_sum
;
}
}
});
});
...
...
test/cpu_ops_test.cpp
View file @
e8afe91a
...
@@ -1002,14 +1002,12 @@ TEST_CASE(logsoftmax_test_axis_0)
...
@@ -1002,14 +1002,12 @@ TEST_CASE(logsoftmax_test_axis_0)
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
std
::
vector
<
float
>
s
=
{
std
::
vector
<
float
>
s
=
{
-
0.135261
,
-
2.843968
,
-
0.659995
,
-
0.488413
,
-
1.051857
,
-
2.812936
,
-
0.135261
,
-
2.843968
,
-
0.659995
,
-
0.488413
,
-
1.051857
,
-
2.812936
,
-
0.250956
,
-
0.353985
,
-
0.250956
,
-
0.353985
,
-
1.155980
,
-
0.603651
,
-
0.211969
,
-
0.175371
,
-
1.155980
,
-
0.603651
,
-
0.211969
,
-
0.175371
,
-
1.336552
,
-
3.885010
,
-
1.871544
,
-
0.837083
,
-
1.336552
,
-
3.885010
,
-
1.871544
,
-
0.837083
,
-
0.887745
,
-
0.433338
,
-
0.887745
,
-
0.433338
,
-
1.158864
,
-
4.911197
,
-
1.147972
,
-
0.666711
,
-
0.996874
,
-
0.981418
,
-
1.158864
,
-
4.911197
,
-
1.147972
,
-
0.666711
,
-
0.996874
,
-
0.981418
,
-
0.851145
,
-
0.853988
,
-
0.858112
,
-
2.067420
,
-
0.059956
,
-
0.727436
,
-
0.950881
,
-
0.429689
,
-
0.851145
,
-
0.853988
,
-
0.858112
,
-
2.067420
,
-
0.059956
,
-
0.727436
,
-
0.061906
,
-
1.505332
,
-
1.210277
,
-
0.377970
,
-
0.791448
,
-
1.655428
,
-
1.827253
,
-
0.304828
,
-
0.950881
,
-
0.429689
,
-
0.061906
,
-
1.505332
,
-
1.210277
,
-
0.377970
,
-
0.020762
,
-
0.167101
,
-
0.567346
,
-
0.530319
,
-
1.045094
,
-
0.376648
,
-
0.007391
,
-
0.381670
,
-
0.791448
,
-
1.655428
,
-
1.827253
,
-
0.304828
,
-
0.020762
,
-
0.167101
,
-
0.567346
,
-
0.530319
,
-
1.045094
,
-
0.376648
,
-
0.007391
,
-
0.381670
,
-
0.720302
,
-
0.460499
,
-
0.469651
,
-
0.556740
,
-
0.554628
,
-
0.551582
};
-
0.720302
,
-
0.460499
,
-
0.469651
,
-
0.556740
,
-
0.554628
,
-
0.551582
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
...
@@ -1037,14 +1035,12 @@ TEST_CASE(logsoftmax_test_axis_1)
...
@@ -1037,14 +1035,12 @@ TEST_CASE(logsoftmax_test_axis_1)
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
std
::
vector
<
float
>
s
=
{
std
::
vector
<
float
>
s
=
{
-
0.550468
,
-
2.132973
,
-
1.549746
,
-
0.650533
,
-
1.051529
,
-
2.248570
,
-
0.550468
,
-
2.132973
,
-
1.549746
,
-
0.650533
,
-
1.051529
,
-
2.248570
,
-
0.141017
,
-
2.028357
,
-
0.141017
,
-
2.028357
,
-
1.947730
,
-
1.511324
,
-
0.166597
,
-
0.379726
,
-
1.947730
,
-
1.511324
,
-
0.166597
,
-
0.379726
,
-
1.965689
,
-
1.172109
,
-
1.475721
,
-
2.700831
,
-
1.965689
,
-
1.172109
,
-
1.475721
,
-
2.700831
,
-
1.537011
,
-
0.658754
,
-
1.537011
,
-
0.658754
,
-
1.596017
,
-
3.353137
,
-
2.266743
,
-
1.084197
,
-
1.076214
,
-
0.406712
,
-
1.596017
,
-
3.353137
,
-
2.266743
,
-
1.084197
,
-
1.076214
,
-
0.406712
,
-
2.743019
,
-
0.425526
,
-
1.079083
,
-
2.139486
,
-
1.270584
,
-
1.024088
,
-
1.154231
,
-
3.201762
,
-
2.743019
,
-
0.425526
,
-
1.079083
,
-
2.139486
,
-
1.270584
,
-
1.024088
,
-
0.888957
,
-
0.532855
,
-
3.103583
,
-
1.221339
,
-
1.355980
,
-
3.531678
,
-
1.438510
,
-
0.975194
,
-
1.154231
,
-
3.201762
,
-
0.888957
,
-
0.532855
,
-
3.103583
,
-
1.221339
,
-
0.080261
,
-
1.162697
,
-
1.568557
,
-
1.398519
,
-
1.322129
,
-
0.470660
,
-
0.370953
,
-
0.907343
,
-
1.355980
,
-
3.531678
,
-
1.438510
,
-
0.975194
,
-
0.080261
,
-
1.162697
,
-
1.568557
,
-
1.398519
,
-
1.322129
,
-
0.470660
,
-
0.370953
,
-
0.907343
,
-
1.179017
,
-
3.312239
,
-
1.286363
,
-
1.586076
,
-
0.345100
,
-
0.824173
};
-
1.179017
,
-
3.312239
,
-
1.286363
,
-
1.586076
,
-
0.345100
,
-
0.824173
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
...
@@ -1072,14 +1068,12 @@ TEST_CASE(logsoftmax_test_axis_2)
...
@@ -1072,14 +1068,12 @@ TEST_CASE(logsoftmax_test_axis_2)
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
std
::
vector
<
float
>
s
=
{
std
::
vector
<
float
>
s
=
{
-
0.495957
,
-
1.031212
,
-
0.245531
,
-
2.013726
,
-
1.339125
,
-
2.465619
,
-
0.495957
,
-
1.031212
,
-
0.245531
,
-
2.013726
,
-
1.339125
,
-
2.465619
,
-
1.356652
,
-
0.964037
,
-
1.356652
,
-
0.964037
,
-
2.019250
,
-
0.214522
,
-
0.289569
,
-
0.234392
,
-
2.019250
,
-
0.214522
,
-
0.289569
,
-
0.234392
,
-
2.086591
,
-
2.684439
,
-
2.851651
,
-
2.674176
,
-
2.086591
,
-
2.684439
,
-
2.851651
,
-
2.674176
,
-
1.697424
,
-
1.889155
,
-
1.697424
,
-
1.889155
,
-
0.401029
,
-
3.064586
,
-
1.173030
,
-
1.306912
,
-
2.177020
,
-
0.834262
,
-
0.401029
,
-
3.064586
,
-
1.173030
,
-
1.306912
,
-
2.177020
,
-
0.834262
,
-
2.818177
,
-
0.174415
,
-
1.361105
,
-
1.024571
,
-
0.106766
,
-
1.167645
,
-
1.072650
,
-
2.576522
,
-
2.818177
,
-
0.174415
,
-
1.361105
,
-
1.024571
,
-
0.106766
,
-
1.167645
,
-
0.569261
,
-
1.207483
,
-
3.679894
,
-
2.095913
,
-
0.504264
,
-
3.039291
,
-
1.290559
,
-
1.156812
,
-
1.072650
,
-
2.576522
,
-
0.569261
,
-
1.207483
,
-
3.679894
,
-
2.095913
,
-
0.126453
,
-
0.551493
,
-
2.506384
,
-
2.646261
,
-
1.905195
,
-
0.206994
,
-
0.191369
,
-
0.959754
,
-
0.504264
,
-
3.039291
,
-
1.290559
,
-
1.156812
,
-
0.126453
,
-
0.551493
,
-
2.506384
,
-
2.646261
,
-
1.905195
,
-
0.206994
,
-
0.191369
,
-
0.959754
,
-
1.948685
,
-
3.671233
,
-
0.875521
,
-
3.111952
,
-
1.905644
,
-
1.6076011
};
-
1.948685
,
-
3.671233
,
-
0.875521
,
-
3.111952
,
-
1.905644
,
-
1.6076011
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
...
@@ -1107,14 +1101,12 @@ TEST_CASE(logsoftmax_test_axis_3)
...
@@ -1107,14 +1101,12 @@ TEST_CASE(logsoftmax_test_axis_3)
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
std
::
vector
<
float
>
s
=
{
std
::
vector
<
float
>
s
=
{
-
0.336904
,
-
3.475825
,
-
1.366154
,
-
0.279366
,
-
2.208430
,
-
2.010934
,
-
0.336904
,
-
3.475825
,
-
1.366154
,
-
0.279366
,
-
2.208430
,
-
2.010934
,
-
0.225511
,
-
2.436562
,
-
0.225511
,
-
2.436562
,
-
2.167785
,
-
1.572415
,
-
1.784104
,
-
0.470789
,
-
2.167785
,
-
1.572415
,
-
1.784104
,
-
0.470789
,
-
1.067459
,
-
1.801948
,
-
0.711023
,
-
2.307197
,
-
1.067459
,
-
1.801948
,
-
0.711023
,
-
2.307197
,
-
1.467087
,
-
0.400681
,
-
1.467087
,
-
0.400681
,
-
0.426983
,
-
3.740518
,
-
1.127681
,
-
1.078919
,
-
2.599005
,
-
0.534965
,
-
0.426983
,
-
3.740518
,
-
1.127681
,
-
1.078919
,
-
2.599005
,
-
0.534965
,
-
2.561400
,
-
0.567617
,
-
1.033025
,
-
2.097713
,
-
0.520463
,
-
1.262245
,
-
1.763230
,
-
2.607658
,
-
2.561400
,
-
0.567617
,
-
1.033025
,
-
2.097713
,
-
0.520463
,
-
1.262245
,
-
0.281299
,
-
0.814243
,
-
2.627210
,
-
0.724131
,
-
0.655704
,
-
2.123055
,
-
1.018163
,
-
2.480634
,
-
1.763230
,
-
2.607658
,
-
0.281299
,
-
0.814243
,
-
2.627210
,
-
0.724131
,
-
0.382599
,
-
1.451479
,
-
1.843102
,
-
0.915303
,
-
0.818078
,
-
1.316929
,
-
0.508875
,
-
2.033541
,
-
0.655704
,
-
2.123055
,
-
1.018163
,
-
2.480634
,
-
0.382599
,
-
1.451479
,
-
1.843102
,
-
0.915303
,
-
0.818078
,
-
1.316929
,
-
0.508875
,
-
2.033541
,
-
1.487672
,
-
2.417791
,
-
0.378360
,
-
2.568531
,
-
0.569794
,
-
1.028032
};
-
1.487672
,
-
2.417791
,
-
0.378360
,
-
2.568531
,
-
0.569794
,
-
1.028032
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment