Commit 63eb0da5 authored by yuguo-Jack
Browse files

llama

parent e9128480
......@@ -434,6 +434,7 @@ PD_REGISTER_KERNEL(cumsum,
GPU,
ALL_LAYOUT,
phi::CumsumKernel,
phi::dtype::float16,
float,
double,
int16_t,
......
......@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// To-do(qili93): fix this after issue resolved
// https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202
#include "paddle/phi/kernels/multinomial_kernel.h"
#ifdef __NVCC__
......@@ -107,14 +103,22 @@ __global__ void sampleMultinomialWithReplacement(
size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x +
threadIdx.x;
#if defined(__NVCC__)
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, offset, &state);
#else
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx, offset, &state);
#endif
int sample = blockIdx.x * blockDim.x + threadIdx.x;
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
if (sample < num_samples) {
#if defined(__NVCC__)
T rng_number = static_cast<T>(curand_uniform4(&state).x);
// Find the bucket that a uniform random number lies in
#else
T rng_number = static_cast<T>(hiprand_uniform4(&state).x);
#endif
int selected_category =
binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories,
norm_probs_data + dist * num_categories,
......@@ -187,7 +191,7 @@ void MultinomialKernel(const Context& dev_ctx,
if (int_num_samples == 1) {
ArgMaxKernel<T, Context>(
dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out);
dev_ctx, rand, -1, true, false, 3, out);
} else {
std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims());
DenseTensor value = Empty<T, Context>(dev_ctx, IntArray(out_dim_vec));
......@@ -283,7 +287,7 @@ void MultinomialKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(multinomial, // cuda_only
PD_REGISTER_KERNEL(multinomial,
GPU,
ALL_LAYOUT,
phi::MultinomialKernel,
......@@ -293,5 +297,3 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only
double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
#endif
......@@ -183,9 +183,9 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
"""
assert (
not core.is_compiled_with_rocm()
), "multinomial op is not supported on ROCM yet."
# assert (
# not core.is_compiled_with_rocm()
# ), "multinomial op is not supported on ROCM yet."
if in_dynamic_mode():
return _C_ops.multinomial(x, num_samples, replacement)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment