Commit 63eb0da5 authored by yuguo-Jack
Browse files

llama

parent e9128480
...@@ -434,6 +434,7 @@ PD_REGISTER_KERNEL(cumsum, ...@@ -434,6 +434,7 @@ PD_REGISTER_KERNEL(cumsum,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::CumsumKernel, phi::CumsumKernel,
phi::dtype::float16,
float, float,
double, double,
int16_t, int16_t,
......
...@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifndef PADDLE_WITH_HIP
// To-do(qili93): fix this after issue resolved
// https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202
#include "paddle/phi/kernels/multinomial_kernel.h" #include "paddle/phi/kernels/multinomial_kernel.h"
#ifdef __NVCC__ #ifdef __NVCC__
...@@ -107,14 +103,22 @@ __global__ void sampleMultinomialWithReplacement( ...@@ -107,14 +103,22 @@ __global__ void sampleMultinomialWithReplacement(
size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x + size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x +
threadIdx.x; threadIdx.x;
#if defined(__NVCC__)
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, idx, offset, &state); curand_init(seed, idx, offset, &state);
#else
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx, offset, &state);
#endif
int sample = blockIdx.x * blockDim.x + threadIdx.x; int sample = blockIdx.x * blockDim.x + threadIdx.x;
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) { for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
if (sample < num_samples) { if (sample < num_samples) {
#if defined(__NVCC__)
T rng_number = static_cast<T>(curand_uniform4(&state).x); T rng_number = static_cast<T>(curand_uniform4(&state).x);
// Find the bucket that a uniform random number lies in #else
T rng_number = static_cast<T>(hiprand_uniform4(&state).x);
#endif
int selected_category = int selected_category =
binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories, binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories,
norm_probs_data + dist * num_categories, norm_probs_data + dist * num_categories,
...@@ -187,7 +191,7 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -187,7 +191,7 @@ void MultinomialKernel(const Context& dev_ctx,
if (int_num_samples == 1) { if (int_num_samples == 1) {
ArgMaxKernel<T, Context>( ArgMaxKernel<T, Context>(
dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out); dev_ctx, rand, -1, true, false, 3, out);
} else { } else {
std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims()); std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims());
DenseTensor value = Empty<T, Context>(dev_ctx, IntArray(out_dim_vec)); DenseTensor value = Empty<T, Context>(dev_ctx, IntArray(out_dim_vec));
...@@ -283,7 +287,7 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -283,7 +287,7 @@ void MultinomialKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(multinomial, // cuda_only PD_REGISTER_KERNEL(multinomial,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::MultinomialKernel, phi::MultinomialKernel,
...@@ -293,5 +297,3 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only ...@@ -293,5 +297,3 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only
double) { double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64); kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
} }
#endif
...@@ -183,9 +183,9 @@ def multinomial(x, num_samples=1, replacement=False, name=None): ...@@ -183,9 +183,9 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
""" """
assert ( # assert (
not core.is_compiled_with_rocm() # not core.is_compiled_with_rocm()
), "multinomial op is not supported on ROCM yet." # ), "multinomial op is not supported on ROCM yet."
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.multinomial(x, num_samples, replacement) return _C_ops.multinomial(x, num_samples, replacement)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment