random_uniform.hpp 4.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

/**
 * Random Uniform distribution operator.  Given a shape, populate it with random
 * values.  Calls to random_uniform using the same randomization seed as a
 * literal input will always generate the same pseudo-random sequence.
 *
 *      Inputs:   (1) randomization seed (any type is allowed)
 *                (2) output buffer argument to be populated.
 *
 *      Attributes:  none
 *
 *      Output:   Returns the buffer from input #2.
 *
 */
#ifndef MIGRAPHX_GUARD_OPERATORS_RANDOM_UNIFORM_HPP
#define MIGRAPHX_GUARD_OPERATORS_RANDOM_UNIFORM_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <random>
#include <string>
#include <type_traits>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

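/**
 * random_uniform populates the passed shape with random numbers, in a uniform
 * distribution.  Range for floating-point data types is [0, 1) (values are
 * generated as double, so rounding to a narrower type such as float may
 * yield exactly 1); for integer types it is [0, <max value for the type>].
 */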
struct random_uniform
{
    // The random_uniform operation needs the random number generator seed
    // to be passed as a runtime input.

    std::string name() const { return "random_uniform"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this, true}.has(2);

        return inputs.at(1);
    }

Brian Pickrell's avatar
Brian Pickrell committed
68
    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
69
70
    {
        // Output goes into the passed buffer, not the shape output.
Brian Pickrell's avatar
Brian Pickrell committed
71
        argument result{dyn_out.computed_shape};
72
73
74
75
76
77
78
        uint64_t local_seed = args[0].at<uint64_t>(0);
        std::mt19937 gen(local_seed);

        result.visit([&](auto output) {
            using type = typename decltype(output)::value_type;
            if constexpr(std::is_integral<type>{})
            {
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#ifdef _MSC_VER
                // According to the C++ specification, the effect is undefined if the result type
                // for the generator is not one of short, int, long, long long, unsigned short,
                // unsigned int, unsigned long, or unsigned long long. See
                // https://en.cppreference.com/w/cpp/numeric/random/uniform_int_distribution.
                if constexpr(sizeof(type) == 1)
                {
                    std::uniform_int_distribution<int> dis{std::numeric_limits<type>::min(),
                                                           std::numeric_limits<type>::max()};
                    std::generate(output.begin(), output.end(), [&] { return dis(gen); });
                }
                else
#endif
                {
                    // default range for all integer types is
                    // (0, std::uniform_int_distribution<type>::max()).
                    // Todo:  enable different ranges
                    std::uniform_int_distribution<type> dis;
                    std::generate(output.begin(), output.end(), [&] { return dis(gen); });
                }
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
            }
            else
            {
                // default real distribution type is double with range (0, 1);
                std::uniform_real_distribution<> dis;
                std::generate(output.begin(), output.end(), [&] { return dis(gen); });
            }
        });
        return result;
    }

    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 1; }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif