clip.h 3.1 KB
Newer Older
goldenfox2025's avatar
goldenfox2025 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#ifndef __CLIP_H__
#define __CLIP_H__

#include "../../elementwise/elementwise.h"
#include "../../operator.h"

/**
 * @brief Define the Descriptor class for the ternary Clip operator.
 *
 * Expands to a `final` class `op::OP::NAMESPACE::Descriptor` deriving from
 * InfiniopDescriptor. It uses the elementwise descriptor fields for a ternary
 * operator whose min_val and max_val operands are tensors, and stores:
 *   - `_dtype`: the data type passed at construction,
 *   - `_info`: the op::elementwise::ElementwiseInfo (moved in),
 *   - `_device_info`: a std::unique_ptr owning the backend's
 *     op::elementwise::NAMESPACE::DeviceImpl,
 *   - `_workspace_size`: the value returned by workspaceSize().
 *
 * Ownership: the constructor adopts the raw `device_info` pointer into
 * `_device_info`, whose default deleter calls `delete` on it when that member
 * is destroyed together with the Descriptor. Pass an object allocated with
 * `new` and do not free it yourself.
 *
 * `~Descriptor()` and `calculate()` are only declared here; their definitions
 * must be provided elsewhere (presumably once per backend -- confirm).
 *
 * @note Inside the macro body use only C-style block comments: a `//` comment
 *       would absorb the trailing line-continuation backslash and everything
 *       spliced after it.
 *
 * @param OP        The operator name (clip)
 * @param NAMESPACE The backend namespace (e.g. cpu or cuda)
 */
#define CLIP_DESCRIPTOR(OP, NAMESPACE)                                        \
                                                                              \
    namespace op::OP::NAMESPACE {                                             \
    class Descriptor final : public InfiniopDescriptor {                      \
        infiniDtype_t _dtype;                                                 \
        op::elementwise::ElementwiseInfo _info;                               \
        std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
        size_t _workspace_size;                                               \
                                                                              \
    public:                                                                   \
        /* Takes ownership of device_info (must come from new). */            \
        Descriptor(                                                           \
            infiniDtype_t dtype,                                              \
            op::elementwise::ElementwiseInfo info,                            \
            op::elementwise::NAMESPACE::DeviceImpl *device_info,              \
            size_t workspace_size,                                            \
            infiniDevice_t device_type,                                       \
            int device_id)                                                    \
            : InfiniopDescriptor{device_type, device_id},                     \
              _dtype(dtype),                                                  \
              _info(std::move(info)),                                         \
              _device_info(device_info),                                      \
              _workspace_size(workspace_size) {}                              \
                                                                              \
        ~Descriptor();                                                        \
                                                                              \
        size_t workspaceSize() const { return _workspace_size; }              \
                                                                              \
        /* inputs presumably {x, min_val, max_val}: TODO confirm order. */    \
        infiniStatus_t calculate(                                             \
            void *workspace, size_t workspace_size,                           \
            void *output,                                                     \
            std::vector<const void *> inputs,                                 \
            void *stream) const;                                              \
    };                                                                        \
    }

#endif // __CLIP_H__