rand_uniform.hpp 3.78 KB
Newer Older
Brian Pickrell's avatar
Brian Pickrell committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

/**
 * Random Uniform distribution operator.  Given a shape, populate it with random
 * values.  Calls to rand_uniform using the same randomization seed will
 * always generate the same pseudo-random sequence.  The seed is given as a
 * runtime argument containing a single value; there is no compile-time seed
 * attribute.
 *
 *      Inputs:   (1) randomization seed (uint32)
 *                (2) the shape of the set to be populated.
 *
 *      Attributes:  dtype  (output data type; integer types are not yet supported)
 *
 *      Output:   Same shape.
 */
41
42
#ifndef MIGRAPHX_GUARD_OPERATORS_RAND_UNIFORM_HPP
#define MIGRAPHX_GUARD_OPERATORS_RAND_UNIFORM_HPP
Brian Pickrell's avatar
Brian Pickrell committed
43
44
45
46
47
48
49
50
51
52
53
54
55

#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/reflect.hpp>
#include <algorithm>
#include <cstdint>
#include <random>
#include <string>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct rand_uniform
{
Brian Pickrell's avatar
Brian Pickrell committed
56
57
    // The rand_uniform operation does not contain a random number generator seed
    // as a member, and expects it to be passed as a runtime input.
Brian Pickrell's avatar
Brian Pickrell committed
58
59

    // todo:  not currently settable
Brian Pickrell's avatar
Brian Pickrell committed
60
61
    float range_min = 0.0f;
    float range_max = 1.0f;
Brian Pickrell's avatar
Brian Pickrell committed
62
63

    // todo:  integer data type(s) not yet supported
Brian Pickrell's avatar
Brian Pickrell committed
64
    shape::type_t dtype = shape::type_t::float_type;
Brian Pickrell's avatar
Brian Pickrell committed
65
66
67
68

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Brian Pickrell's avatar
Brian Pickrell committed
69
        return pack(f(self.dtype, "dtype"));
Brian Pickrell's avatar
Brian Pickrell committed
70
71
    }

Brian Pickrell's avatar
Brian Pickrell committed
72
73
74
75
    /**
     *   Input 1:  seed
     *   Input 2:  output shape
     */
Brian Pickrell's avatar
Brian Pickrell committed
76
77
78
    std::string name() const { return "rand_uniform"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Brian Pickrell's avatar
Brian Pickrell committed
79
        check_shapes{inputs, *this, true}.has(2);
Brian Pickrell's avatar
Brian Pickrell committed
80

Brian Pickrell's avatar
Brian Pickrell committed
81
        if(inputs.front().type() != shape::type_t::uint32_type)
Brian Pickrell's avatar
Brian Pickrell committed
82
            MIGRAPHX_THROW("RAND_UNIFORM:  Input 2 (seed) must have type unsigned int");
Brian Pickrell's avatar
Brian Pickrell committed
83
        auto s = inputs.at(1);
Brian Pickrell's avatar
Brian Pickrell committed
84
85
86
87
88
89
90
91
92
93
        if(s.dynamic())
        {
            return s.with_type(dtype);
        }
        else
        {
            return s.with_lens(s.lens()).with_type(dtype);
        }
    }

Brian Pickrell's avatar
Brian Pickrell committed
94
    argument compute(const shape& output, std::vector<argument> args) const
Brian Pickrell's avatar
Brian Pickrell committed
95
    {
Brian Pickrell's avatar
Brian Pickrell committed
96
97
98
        // Output goes into the passed buffer, not the shape output
        (void) output;
        argument result{args[1].get_shape()};
Brian Pickrell's avatar
Brian Pickrell committed
99

Brian Pickrell's avatar
Brian Pickrell committed
100
        uint32_t local_seed = args[0].at<uint32_t>(0);
Brian Pickrell's avatar
Brian Pickrell committed
101
102
103

        std::mt19937 gen(local_seed);
        std::uniform_real_distribution<> dis(range_min, range_max);
Brian Pickrell's avatar
Brian Pickrell committed
104
105
        result.visit([&](auto output_shape) {
            std::generate(output_shape.begin(), output_shape.end(), [&]() { return dis(gen); });
Brian Pickrell's avatar
Brian Pickrell committed
106
107
108
        });
        return result;
    }
Brian Pickrell's avatar
Brian Pickrell committed
109

Brian Pickrell's avatar
Brian Pickrell committed
110
    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 1; }
Brian Pickrell's avatar
Brian Pickrell committed
111
112
113
114
115
116
117
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif