Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
63eb0da5
Commit
63eb0da5
authored
Dec 12, 2023
by
yuguo-Jack
Browse files
llama
parent
e9128480
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
12 deletions
+15
-12
paddle/phi/kernels/gpu/cum_kernel.cu
paddle/phi/kernels/gpu/cum_kernel.cu
+1
-0
paddle/phi/kernels/gpu/multinomial_kernel.cu
paddle/phi/kernels/gpu/multinomial_kernel.cu
+11
-9
python/paddle/tensor/random.py
python/paddle/tensor/random.py
+3
-3
No files found.
paddle/phi/kernels/gpu/cum_kernel.cu
View file @
63eb0da5
...
...
@@ -434,6 +434,7 @@ PD_REGISTER_KERNEL(cumsum,
GPU
,
ALL_LAYOUT
,
phi
::
CumsumKernel
,
phi
::
dtype
::
float16
,
float
,
double
,
int16_t
,
...
...
paddle/phi/kernels/gpu/multinomial_kernel.cu
View file @
63eb0da5
...
...
@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// To-do(qili93): fix this after issue resolved
// https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202
#include "paddle/phi/kernels/multinomial_kernel.h"
#ifdef __NVCC__
...
...
@@ -107,14 +103,22 @@ __global__ void sampleMultinomialWithReplacement(
size_t
idx
=
gridDim
.
x
*
blockDim
.
x
*
blockIdx
.
y
+
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
#if defined(__NVCC__)
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
offset
,
&
state
);
#else
hiprandStatePhilox4_32_10_t
state
;
hiprand_init
(
seed
,
idx
,
offset
,
&
state
);
#endif
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
dist
=
blockIdx
.
y
;
dist
<
num_distributions
;
dist
+=
gridDim
.
y
)
{
if
(
sample
<
num_samples
)
{
#if defined(__NVCC__)
T
rng_number
=
static_cast
<
T
>
(
curand_uniform4
(
&
state
).
x
);
// Find the bucket that a uniform random number lies in
#else
T
rng_number
=
static_cast
<
T
>
(
hiprand_uniform4
(
&
state
).
x
);
#endif
int
selected_category
=
binarySearchFunctor
<
T
>
(
cumulative_probs_data
+
dist
*
num_categories
,
norm_probs_data
+
dist
*
num_categories
,
...
...
@@ -187,7 +191,7 @@ void MultinomialKernel(const Context& dev_ctx,
if
(
int_num_samples
==
1
)
{
ArgMaxKernel
<
T
,
Context
>
(
dev_ctx
,
rand
,
-
1
,
true
,
false
,
3
/*proto::VarType::INT64*/
,
out
);
dev_ctx
,
rand
,
-
1
,
true
,
false
,
3
,
out
);
}
else
{
std
::
vector
<
int64_t
>
out_dim_vec
=
vectorize
<
int64_t
>
(
out
->
dims
());
DenseTensor
value
=
Empty
<
T
,
Context
>
(
dev_ctx
,
IntArray
(
out_dim_vec
));
...
...
@@ -283,7 +287,7 @@ void MultinomialKernel(const Context& dev_ctx,
}
// namespace phi
PD_REGISTER_KERNEL
(
multinomial
,
// cuda_only
PD_REGISTER_KERNEL
(
multinomial
,
GPU
,
ALL_LAYOUT
,
phi
::
MultinomialKernel
,
...
...
@@ -293,5 +297,3 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only
double
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
INT64
);
}
#endif
python/paddle/tensor/random.py
View file @
63eb0da5
...
...
@@ -183,9 +183,9 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
"""
assert
(
not
core
.
is_compiled_with_rocm
()
),
"multinomial op is not supported on ROCM yet."
#
assert (
#
not core.is_compiled_with_rocm()
#
), "multinomial op is not supported on ROCM yet."
if
in_dynamic_mode
():
return
_C_ops
.
multinomial
(
x
,
num_samples
,
replacement
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment