jerrrrry / infinicore · Commits

Commit ed530e11
Authored Sep 29, 2025 by pengcheng888

issue/427 - add the sigmoid, topksoftmax, and topkrouter ops
parent 3959c943

Changes: 38. Showing 20 changed files with 954 additions and 58 deletions (+954 / -58); the file list continues on a second page.
include/infiniop.h                                          +2    -0
include/infiniop/ops/sigmoid.h                              +24   -0
include/infiniop/ops/topkrouter.h                           +14   -7
include/infiniop/ops/topksoftmax.h                          +26   -0
scripts/python_test.py                                      +3    -0
src/infiniop-test/include/ops.hpp                           +6    -1
src/infiniop-test/src/ops/sigmoid.cpp                       +103  -0
src/infiniop-test/src/ops/topkrouter.cpp                    +130  -0
src/infiniop-test/src/ops/topksoftmax.cpp                   +122  -0
src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc                 +51   -0
src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h                  +19   -0
src/infiniop/ops/sigmoid/cuda/kernel.cuh                    +34   -0
src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cu           +58   -0
src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cuh          +8    -0
src/infiniop/ops/sigmoid/operator.cc                        +115  -0
src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc           +202  -15
src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.h            +2    -3
src/infiniop/ops/topkrouter/cuda/kernel.cuh                 +20   -22
src/infiniop/ops/topkrouter/info.h                          +2    -2
src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu     +13   -8
include/infiniop.h    (+2 -0)

@@ -19,6 +19,8 @@
 #include "infiniop/ops/sub.h"
 #include "infiniop/ops/swiglu.h"
 #include "infiniop/ops/topkrouter.h"
+#include "infiniop/ops/topksoftmax.h"
+#include "infiniop/ops/sigmoid.h"
 #include "infiniop/tensor_descriptor.h"
 #endif // __INFINIOP_API_H__
include/infiniop/ops/sigmoid.h    (new file, +24 -0)

#ifndef __INFINIOP_SIGMOID_API_H__
#define __INFINIOP_SIGMOID_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopSigmoidDescriptor_t;

__C __export infiniStatus_t infiniopCreateSigmoidDescriptor(
    infiniopHandle_t handle,
    infiniopSigmoidDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y,
    infiniopTensorDescriptor_t x);

__C __export infiniStatus_t infiniopGetSigmoidWorkspaceSize(
    infiniopSigmoidDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopSigmoid(
    infiniopSigmoidDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    void *stream);

__C __export infiniStatus_t infiniopDestroySigmoidDescriptor(
    infiniopSigmoidDescriptor_t desc);

#endif
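For callers, the intended sequence is the usual infiniop one: create the descriptor from tensor descriptors, query the workspace size, run, and destroy. Below is a minimal host-side sketch that mirrors what the sigmoid test in src/infiniop-test does further down; the handle, the tensor descriptors, and the device buffers are assumed to have been created elsewhere, and run_sigmoid is a hypothetical helper name, not part of the library.

// Sketch only: assumes `handle`, `y_desc`/`x_desc` and the device buffers
// `y`/`x` were set up by the caller, as in the infiniop-test runner.
infiniStatus_t run_sigmoid(infiniopHandle_t handle,
                           infiniopTensorDescriptor_t y_desc,
                           infiniopTensorDescriptor_t x_desc,
                           void *y, const void *x, void *stream) {
    infiniopSigmoidDescriptor_t desc;
    infiniStatus_t st = infiniopCreateSigmoidDescriptor(handle, &desc, y_desc, x_desc);
    if (st != INFINI_STATUS_SUCCESS) return st;

    size_t workspace_size = 0;
    st = infiniopGetSigmoidWorkspaceSize(desc, &workspace_size);
    if (st != INFINI_STATUS_SUCCESS) return st;

    void *workspace = nullptr;
    if (workspace_size > 0) {
        infinirtMalloc(&workspace, workspace_size); // runtime allocator used by the tests
    }

    st = infiniopSigmoid(desc, workspace, workspace_size, y, x, stream);

    if (workspace) infinirtFree(workspace);
    infiniopDestroySigmoidDescriptor(desc);
    return st;
}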
include/infiniop/ops/topkrouter.h    (+14 -7)

@@ -5,16 +5,23 @@  (content after this commit: infiniopTopkrouter now takes a workspace and const-qualified inputs)

typedef struct InfiniopDescriptor *infiniopTopkrouterDescriptor_t;

__C __export infiniStatus_t infiniopCreateTopkrouterDescriptor(
    infiniopHandle_t handle,
    infiniopTopkrouterDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t correction_bias_desc);

__C __export infiniStatus_t infiniopGetTopkrouterWorkspaceSize(
    infiniopTopkrouterDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopTopkrouter(
    infiniopTopkrouterDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *values,
    void *indices,
    const void *x,
    const void *correction_bias,
    const float routed_scaling_factor,
    const size_t topk,
    void *stream);

__C __export infiniStatus_t infiniopDestroyTopkrouterDescriptor(
    infiniopTopkrouterDescriptor_t desc);
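The practical effect of this header change is that infiniopTopkrouter now follows the same workspace convention as the other ops: the caller queries the size with infiniopGetTopkrouterWorkspaceSize, allocates the buffer, and passes it ahead of the output pointers, while the read-only arguments are const-qualified. A minimal sketch of the new call sequence follows; the descriptor, the device buffers, and the stream are assumed to exist already, and run_topkrouter is a hypothetical helper name.

// Sketch only: the descriptor and device buffers are created elsewhere,
// e.g. as in src/infiniop-test/src/ops/topkrouter.cpp further down.
infiniStatus_t run_topkrouter(infiniopTopkrouterDescriptor_t desc,
                              float *values, int *indices,   // outputs, shape [N, topk]
                              const void *x,                 // router logits, shape [N, width]
                              const float *correction_bias,  // shape [width]
                              float routed_scaling_factor, size_t topk,
                              void *stream) {
    size_t ws_size = 0;
    infiniStatus_t st = infiniopGetTopkrouterWorkspaceSize(desc, &ws_size);
    if (st != INFINI_STATUS_SUCCESS) return st;

    void *workspace = nullptr;
    if (ws_size > 0) {
        infinirtMalloc(&workspace, ws_size); // runtime allocator used by the tests
    }

    st = infiniopTopkrouter(desc, workspace, ws_size,
                            values, indices, x, correction_bias,
                            routed_scaling_factor, topk, stream);

    if (workspace) {
        infinirtFree(workspace);
    }
    return st;
}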
include/infiniop/ops/topksoftmax.h    (new file, +26 -0)

#ifndef __INFINIOP_TOPKSOFTMAX_API_H__
#define __INFINIOP_TOPKSOFTMAX_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopTopksoftmaxDescriptor_t;

__C __export infiniStatus_t infiniopCreateTopksoftmaxDescriptor(
    infiniopHandle_t handle,
    infiniopTopksoftmaxDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t x_desc);

__C __export infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(
    infiniopTopksoftmaxDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopTopksoftmax(
    infiniopTopksoftmaxDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *values,
    void *indices,
    const void *x,
    const size_t topk,
    const int norm,
    void *stream);

__C __export infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(
    infiniopTopksoftmaxDescriptor_t desc);

#endif
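The diff itself does not spell out the math of this op. Judging from the signature and the test tensors below, it appears to compute a row-wise softmax over x and return the top-k probabilities and their indices, with norm re-normalizing the selected probabilities to sum to one. A scalar reference sketch under that assumption is shown here; it is only an illustration, not the library's kernel, and the actual implementation may order softmax and selection differently.

// Reference sketch only (not the library implementation): one row of logits,
// assuming softmax followed by top-k selection, with optional renormalization
// of the selected probabilities when `norm` is set.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

void topksoftmax_row_ref(const std::vector<float> &logits, size_t topk, bool norm,
                         std::vector<float> &values, std::vector<int> &indices) {
    // softmax over the full row (max-subtracted for numerical stability)
    float mx = *std::max_element(logits.begin(), logits.end());
    std::vector<float> prob(logits.size());
    float sum = 0.f;
    for (size_t i = 0; i < logits.size(); ++i) {
        prob[i] = std::exp(logits[i] - mx);
        sum += prob[i];
    }
    for (auto &p : prob) p /= sum;

    // indices of the top-k probabilities
    std::vector<int> order(logits.size());
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + topk, order.end(),
                      [&](int a, int b) { return prob[a] > prob[b]; });

    values.resize(topk);
    indices.resize(topk);
    float picked = 0.f;
    for (size_t i = 0; i < topk; ++i) {
        indices[i] = order[i];
        values[i] = prob[order[i]];
        picked += values[i];
    }
    if (norm) { // renormalize the selected probabilities to sum to 1
        for (auto &v : values) v /= picked;
    }
}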
scripts/python_test.py    (+3 -0)

@@ -25,6 +25,9 @@ def run_tests(args):
         "sub.py",
         "swiglu.py",
         "softplus.py",
+        "sigmoid.py",
+        "topkrouter.py",
+        "topksoftmax.py",
     ]:
         result = subprocess.run(
             f"python {test} {args} --debug",
             text=True,
             encoding="utf-8",
             shell=True
src/infiniop-test/include/ops.hpp    (+6 -1)

@@ -16,7 +16,9 @@ DECLARE_INFINIOP_TEST(add)
 DECLARE_INFINIOP_TEST(causal_softmax)
 DECLARE_INFINIOP_TEST(rearrange)
 DECLARE_INFINIOP_TEST(sub)
+DECLARE_INFINIOP_TEST(sigmoid)
+DECLARE_INFINIOP_TEST(topkrouter)
+DECLARE_INFINIOP_TEST(topksoftmax)

 #define REGISTER_INFINIOP_TEST(name) \
     {                                \
         #name,                       \

@@ -43,6 +45,9 @@ DECLARE_INFINIOP_TEST(sub)
     REGISTER_INFINIOP_TEST(causal_softmax) \
     REGISTER_INFINIOP_TEST(rearrange)      \
     REGISTER_INFINIOP_TEST(sub)            \
+    REGISTER_INFINIOP_TEST(sigmoid)        \
+    REGISTER_INFINIOP_TEST(topkrouter)     \
+    REGISTER_INFINIOP_TEST(topksoftmax)    \
     }

 namespace infiniop_test {
src/infiniop-test/src/ops/sigmoid.cpp    (new file, +103 -0)

#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>

namespace infiniop_test::sigmoid {

struct Test::Attributes {
    std::shared_ptr<Tensor> x;
    std::shared_ptr<Tensor> y;
    std::shared_ptr<Tensor> ans;
};

std::shared_ptr<Test> Test::build(
    std::unordered_map<std::string, std::vector<uint8_t>> attributes,
    std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
    double rtol, double atol) {
    auto test = std::shared_ptr<Test>(new Test(rtol, atol));
    test->_attributes = new Attributes();
    if (tensors.find("x") == tensors.end()
        || tensors.find("y") == tensors.end()
        || tensors.find("ans") == tensors.end()) {
        throw std::runtime_error("Invalid Test");
    }
    test->_attributes->x = tensors["x"];
    test->_attributes->y = tensors["y"];
    test->_attributes->ans = tensors["ans"];
    return test;
}

std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id,
    size_t warm_ups, size_t iterations) {
    infiniopSigmoidDescriptor_t op_desc;
    auto x = _attributes->x->to(device, device_id);
    auto y = _attributes->y->to(device, device_id);
    CHECK_OR(infiniopCreateSigmoidDescriptor(handle, &op_desc, y->desc(), x->desc()),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));
    size_t workspace_size;
    CHECK_OR(infiniopGetSigmoidWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
    void *workspace;
    CHECK_OR(infinirtMalloc(&workspace, workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
    CHECK_OR(infiniopSigmoid(op_desc, workspace, workspace_size, y->data(), x->data(), nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));

    try {
        allClose(y, _attributes->ans, _rtol, _atol);
    } catch (const std::exception &e) {
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    double elapsed_time = 0.;
    elapsed_time = benchmark(
        [=]() {
            infiniopSigmoid(op_desc, workspace, workspace_size, y->data(), x->data(), nullptr);
        },
        warm_ups, iterations);

    infiniopDestroySigmoidDescriptor(op_desc);
    infinirtFree(workspace);

    return TEST_PASSED(elapsed_time);
}

std::vector<std::string> Test::attribute_names() {
    return {};
}

std::vector<std::string> Test::tensor_names() {
    return {"x", "y", "ans"};
}

std::vector<std::string> Test::output_names() {
    return {"y"};
}

std::string Test::toString() const {
    std::ostringstream oss;
    oss << op_name() << std::endl;
    oss << "- x: " << _attributes->x->info() << std::endl;
    oss << "- y: " << _attributes->y->info() << std::endl;
    oss << std::scientific << std::setprecision(2);
    oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
    return oss.str();
}

Test::~Test() {
    delete _attributes;
}

} // namespace infiniop_test::sigmoid
src/infiniop-test/src/ops/topkrouter.cpp    (new file, +130 -0)

#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>

namespace infiniop_test::topkrouter {

struct Test::Attributes {
    std::shared_ptr<Tensor> values;
    std::shared_ptr<Tensor> indices;
    std::shared_ptr<Tensor> x;
    std::shared_ptr<Tensor> correction_bias;
    float routed_scaling_factor;
    int topk;
    std::shared_ptr<Tensor> lable_values;
    std::shared_ptr<Tensor> lable_indices;
};

std::shared_ptr<Test> Test::build(
    std::unordered_map<std::string, std::vector<uint8_t>> attributes,
    std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
    double rtol, double atol) {
    auto test = std::shared_ptr<Test>(new Test(rtol, atol));
    test->_attributes = new Attributes();
    if (attributes.find("routed_scaling_factor") == attributes.end()
        || attributes.find("topk") == attributes.end()
        || tensors.find("values") == tensors.end()
        || tensors.find("indices") == tensors.end()
        || tensors.find("x") == tensors.end()
        || tensors.find("correction_bias") == tensors.end()
        || tensors.find("lable_values") == tensors.end()
        || tensors.find("lable_indices") == tensors.end()) {
        throw std::runtime_error("Invalid Test: Missing attributes or tensors");
    }
    test->_attributes->values = tensors["values"];
    test->_attributes->indices = tensors["indices"];
    test->_attributes->x = tensors["x"];
    test->_attributes->correction_bias = tensors["correction_bias"];
    test->_attributes->routed_scaling_factor = *reinterpret_cast<float *>(attributes["routed_scaling_factor"].data());
    test->_attributes->topk = *reinterpret_cast<int *>(attributes["topk"].data());
    test->_attributes->lable_values = tensors["lable_values"];
    test->_attributes->lable_indices = tensors["lable_indices"];
    return test;
}

std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id,
    size_t warm_ups, size_t iterations) {
    infiniopTopkrouterDescriptor_t op_desc;
    CHECK_OR(infiniopCreateTopkrouterDescriptor(handle, &op_desc, _attributes->x->desc(), _attributes->correction_bias->desc()),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create Topkrouter descriptor"));

    auto values = _attributes->values->to(device, device_id);
    auto indices = _attributes->indices->to(device, device_id);
    auto x = _attributes->x->to(device, device_id);
    auto correction_bias = _attributes->correction_bias->to(device, device_id);
    float routed_scaling_factor = _attributes->routed_scaling_factor;
    int topk = _attributes->topk;

    size_t workspace_size;
    CHECK_OR(infiniopGetTopkrouterWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size"));
    void *workspace = nullptr;
    if (workspace_size > 0) {
        CHECK_OR(infinirtMalloc(&workspace, workspace_size),
                 return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace"));
    }

    CHECK_OR(infiniopTopkrouter(op_desc, workspace, workspace_size,
                                values->data(), indices->data(), x->data(), correction_bias->data(),
                                routed_scaling_factor, topk, nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "Topkrouter execution failed"));

    try {
        allClose(values, _attributes->lable_values, _rtol, _atol);
        allClose(indices, _attributes->lable_indices, _rtol, _atol);
    } catch (const std::exception &e) {
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    double elapsed_time = 0.;
    elapsed_time = benchmark(
        [=]() {
            infiniopTopkrouter(op_desc, workspace, workspace_size,
                               values->data(), indices->data(), x->data(), correction_bias->data(),
                               routed_scaling_factor, topk, nullptr);
        },
        warm_ups, iterations);

    if (workspace != nullptr) {
        infinirtFree(workspace);
    }

    return TEST_PASSED(elapsed_time);
}

std::vector<std::string> Test::attribute_names() {
    return {"routed_scaling_factor", "topk"};
}

std::vector<std::string> Test::tensor_names() {
    return {"values", "indices", "x", "correction_bias", "lable_values", "lable_indices"};
}

std::vector<std::string> Test::output_names() {
    return {"values", "indices"};
}

std::string Test::toString() const {
    std::ostringstream oss;
    oss << op_name() << std::endl;
    oss << "- routed_scaling_factor=" << _attributes->routed_scaling_factor << std::endl;
    oss << "- topk=" << _attributes->topk << std::endl;
    oss << "- values: " << _attributes->values->info() << std::endl;
    oss << "- indices: " << _attributes->indices->info() << std::endl;
    oss << "- x: " << _attributes->x->info() << std::endl;
    oss << "- correction_bias: " << _attributes->correction_bias->info() << std::endl;
    oss << "- lable_values: " << _attributes->lable_values->info() << std::endl;
    oss << "- lable_indices: " << _attributes->lable_indices->info() << std::endl;
    oss << std::scientific << std::setprecision(2);
    oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
    return oss.str();
}

Test::~Test() {
    delete _attributes;
}

} // namespace infiniop_test::topkrouter
src/infiniop-test/src/ops/topksoftmax.cpp    (new file, +122 -0)

#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>

namespace infiniop_test::topksoftmax {

struct Test::Attributes {
    std::shared_ptr<Tensor> values;
    std::shared_ptr<Tensor> indices;
    std::shared_ptr<Tensor> x;
    int topk;
    bool norm;
    std::shared_ptr<Tensor> lable_values;
    std::shared_ptr<Tensor> lable_indices;
};

std::shared_ptr<Test> Test::build(
    std::unordered_map<std::string, std::vector<uint8_t>> attributes,
    std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
    double rtol, double atol) {
    auto test = std::shared_ptr<Test>(new Test(rtol, atol));
    test->_attributes = new Attributes();
    if (attributes.find("topk") == attributes.end()
        || attributes.find("norm") == attributes.end()
        || tensors.find("values") == tensors.end()
        || tensors.find("indices") == tensors.end()
        || tensors.find("x") == tensors.end()
        || tensors.find("lable_values") == tensors.end()
        || tensors.find("lable_indices") == tensors.end()) {
        throw std::runtime_error("Invalid Test: Missing attributes or tensors");
    }
    test->_attributes->values = tensors["values"];
    test->_attributes->indices = tensors["indices"];
    test->_attributes->x = tensors["x"];
    test->_attributes->topk = *reinterpret_cast<int *>(attributes["topk"].data());
    test->_attributes->norm = *reinterpret_cast<bool *>(attributes["norm"].data());
    test->_attributes->lable_values = tensors["lable_values"];
    test->_attributes->lable_indices = tensors["lable_indices"];
    return test;
}

std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id,
    size_t warm_ups, size_t iterations) {
    infiniopTopksoftmaxDescriptor_t op_desc;
    CHECK_OR(infiniopCreateTopksoftmaxDescriptor(handle, &op_desc, _attributes->x->desc()),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create Topksoftmax descriptor"));

    auto values = _attributes->values->to(device, device_id);
    auto indices = _attributes->indices->to(device, device_id);
    auto x = _attributes->x->to(device, device_id);
    int topk = _attributes->topk;
    bool norm = _attributes->norm;

    size_t workspace_size;
    CHECK_OR(infiniopGetTopksoftmaxWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size"));
    void *workspace = nullptr;
    if (workspace_size > 0) {
        CHECK_OR(infinirtMalloc(&workspace, workspace_size),
                 return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace"));
    }

    CHECK_OR(infiniopTopksoftmax(op_desc, workspace, workspace_size,
                                 values->data(), indices->data(), x->data(),
                                 topk, norm, nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "Topksoftmax execution failed"));

    try {
        allClose(values, _attributes->lable_values, _rtol, _atol);
        allClose(indices, _attributes->lable_indices, _rtol, _atol);
    } catch (const std::exception &e) {
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    double elapsed_time = 0.;
    elapsed_time = benchmark(
        [=]() {
            infiniopTopksoftmax(op_desc, workspace, workspace_size,
                                values->data(), indices->data(), x->data(),
                                topk, norm, nullptr);
        },
        warm_ups, iterations);

    if (workspace != nullptr) {
        infinirtFree(workspace);
    }

    return TEST_PASSED(elapsed_time);
}

std::vector<std::string> Test::attribute_names() {
    return {"topk", "norm"};
}

std::vector<std::string> Test::tensor_names() {
    return {"values", "indices", "x", "lable_values", "lable_indices"};
}

std::vector<std::string> Test::output_names() {
    return {"values", "indices"};
}

std::string Test::toString() const {
    std::ostringstream oss;
    oss << op_name() << std::endl;
    oss << "- topk=" << _attributes->topk << std::endl;
    oss << "- norm=" << _attributes->norm << std::endl;
    oss << "- values: " << _attributes->values->info() << std::endl;
    oss << "- indices: " << _attributes->indices->info() << std::endl;
    oss << "- x: " << _attributes->x->info() << std::endl;
    oss << "- lable_values: " << _attributes->lable_values->info() << std::endl;
    oss << "- lable_indices: " << _attributes->lable_indices->info() << std::endl;
    oss << std::scientific << std::setprecision(2);
    oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
    return oss.str();
}

Test::~Test() {
    delete _attributes;
}

} // namespace infiniop_test::topksoftmax
src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc    (new file, +51 -0)

#include "sigmoid_cpu.h"

namespace op::sigmoid::cpu {

Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &x_desc = input_desc_vec.at(0);
    const auto &y_shape = out_desc->shape();
    const auto &x_shape = x_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(y_shape, x_shape);

    // create CPU elementwise descriptor
    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<SigmoidOp, fp16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<SigmoidOp, float>(_info, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<SigmoidOp, double>(_info, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<SigmoidOp, bf16_t>(_info, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::sigmoid::cpu
src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h    (new file, +19 -0)

#ifndef __SIGMOID_CPU_H__
#define __SIGMOID_CPU_H__

#include "../../../elementwise/cpu/elementwise_cpu.h"

ELEMENTWISE_DESCRIPTOR(sigmoid, cpu)

namespace op::sigmoid::cpu {
typedef struct SigmoidOp {
public:
    static constexpr size_t num_inputs = 1;
    template <typename T>
    T operator()(const T &x) const {
        return T(1) / (T(1) + std::exp(-x));
    }
} SigmoidOp;
} // namespace op::sigmoid::cpu

#endif // __SIGMOID_CPU_H__
src/infiniop/ops/sigmoid/cuda/kernel.cuh    (new file, +34 -0)

#ifndef __SIDMOID_CUDA_H__
#define __SIDMOID_CUDA_H__

#include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace op::sigmoid::cuda {
typedef struct SigmoidOp {
public:
    static constexpr size_t num_inputs = 1;
    template <typename T>
    __device__ __forceinline__ T operator()(const T &x) const {
        // sigmoid(x) = 1 / (1 + exp(-x))
        if constexpr (std::is_same_v<T, half2>) {
            half2 denominator = __hadd2(make_half2(1, 1), h2exp(__hneg2(x)));
            return h2rcp(denominator);
        } else if constexpr (std::is_same_v<T, half>) {
            half denominator = __hadd(__float2half(1.0f), hexp(__hneg(x)));
            return hrcp(denominator);
        } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
            __nv_bfloat16 denominator = __float2bfloat16(__fadd_rn(1.0f, __expf(__bfloat162float(-x))));
            return __float2bfloat16(1.0f) / denominator;
        } else if constexpr (std::is_same_v<T, float>) {
            float denominator = __fadd_rn(1.0f, __expf(-x));
            return __frcp_rn(denominator);
        } else {
            // double
            return 1.0 / (1.0 + exp(-x));
        }
    }
} SigmoidOp;
} // namespace op::sigmoid::cuda

#endif // __SIDMOID_CUDA_H__
src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cu    (new file, +58 -0)

#include "../cuda/kernel.cuh"
#include "sigmoid_nvidia.cuh"

namespace op::sigmoid::nvidia {

Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &x_desc = input_desc_vec.at(0);
    const auto &y_shape = out_desc->shape();
    const auto &x_shape = x_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(y_shape, x_shape);

    // create CUDA elementwise descriptor
    CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::SigmoidOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::SigmoidOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::SigmoidOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::SigmoidOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::sigmoid::nvidia
src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cuh    (new file, +8 -0)

#ifndef __SIGMOID_CUDA_API_H__
#define __SIGMOID_CUDA_API_H__

#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"

ELEMENTWISE_DESCRIPTOR(sigmoid, nvidia)

#endif // __SIGMOID_CUDA_API_H__
src/infiniop/ops/sigmoid/operator.cc    (new file, +115 -0)

#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/sigmoid.h"

#ifdef ENABLE_CPU_API
#include "cpu/sigmoid_cpu.h"
#endif
#ifdef ENABLE_NVIDIA_API
#include "nvidia/sigmoid_nvidia.cuh"
#endif

__C infiniStatus_t infiniopCreateSigmoidDescriptor(
    infiniopHandle_t handle,
    infiniopSigmoidDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {

#define CREATE(CASE, NAMESPACE)                                                 \
    case CASE:                                                                  \
        return op::sigmoid::NAMESPACE::Descriptor::create(                      \
            handle,                                                             \
            reinterpret_cast<op::sigmoid::NAMESPACE::Descriptor **>(desc_ptr),  \
            y_desc,                                                             \
            {x_desc})

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}

__C infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t desc, size_t *size) {

#define GET(CASE, NAMESPACE)                                                                     \
    case CASE:                                                                                   \
        *size = reinterpret_cast<op::sigmoid::NAMESPACE::Descriptor *>(desc)->workspaceSize();  \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET

    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopSigmoid(
    infiniopSigmoidDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE)                                                  \
    case CASE:                                                                      \
        return reinterpret_cast<const op::sigmoid::NAMESPACE::Descriptor *>(desc)   \
            ->calculate(workspace, workspace_size, y, {x}, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}

__C infiniStatus_t infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc) {

#define DELETE(CASE, NAMESPACE)                                                      \
    case CASE:                                                                       \
        delete reinterpret_cast<const op::sigmoid::NAMESPACE::Descriptor *>(desc);   \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DELETE
}
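Each entry point in operator.cc uses a short-lived local macro to generate one switch case per enabled backend and #undef's it right after the switch. For instance, CREATE(INFINI_DEVICE_CPU, cpu) inside infiniopCreateSigmoidDescriptor expands to roughly:

// Expansion of CREATE(INFINI_DEVICE_CPU, cpu) inside the switch above
case INFINI_DEVICE_CPU:
    return op::sigmoid::cpu::Descriptor::create(
        handle,
        reinterpret_cast<op::sigmoid::cpu::Descriptor **>(desc_ptr),
        y_desc,
        {x_desc});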
src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc    (+202 -15)

Content after this commit (the previous CPU calculate() was a stub that returned INFINI_STATUS_NOT_IMPLEMENTED); comments translated from Chinese:

#include "topkrouter_cpu.h"
#include "../../../../utils.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
#include <algorithm>

namespace op::topkrouter::cpu {

Descriptor::~Descriptor() {}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t correction_bias_desc) {
    auto result = TopkrouterInfo::create(x_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    if (info.x_strides[1] != 1) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    *desc_ptr = new Descriptor(nullptr, std::move(info), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

template <typename T>
inline float sigmoid_func(T x) {
    float value;
    if constexpr (std::is_same<T, fp16_t>::value) {
        value = _f16_to_f32(x);
    } else if constexpr (std::is_same<T, bf16_t>::value) {
        value = _bf16_to_f32(x);
    } else {
        value = x;
    }
    return 1.0f / (1.0f + std::exp(-value));
}

template <typename T>
void topkrouter_cpu_one_token(float *values_input,                                    // output values
                              int *indices_input,                                     // output indices
                              const T *x_input,                                       // input data
                              std::vector<std::pair<float, size_t>> &value_index_arr, // input data
                              const float *correction_bias,
                              const float routed_scaling_factor,
                              const size_t topk,
                              const size_t width,
                              const size_t n_routed_experts,
                              const size_t n_group,
                              const size_t topk_group,
                              const bool norm_topk_prob) {
    // ------------------------------------------------------ //
    // apply sigmoid to the input scores                       //
    // ------------------------------------------------------ //
    for (size_t i = 0; i < width; ++i) {
        value_index_arr[i].first = sigmoid_func(value_index_arr[i].first);
    }

    // ------------------------------------------------------ //
    // then add the correction bias                            //
    // ------------------------------------------------------ //
    for (size_t i = 0; i < width; ++i) {
        value_index_arr[i].first += correction_bias[i];
    }

    // ----------------------------------------------------------- //
    // split into n_group groups and score each group               //
    // ----------------------------------------------------------- //
    std::vector<std::pair<float, size_t>> value_index_group;
    value_index_group.resize(n_group);
    const size_t group_size = width / n_group; // number of elements per group
    for (size_t igroup = 0; igroup < n_group; ++igroup) {
        std::vector<std::pair<float, size_t>> value_index_warp;
        value_index_warp.resize(group_size);
        auto it = value_index_arr.begin() + igroup * group_size;
        for (size_t i = 0; i < group_size; ++i) {
            value_index_warp[i] = {(it++)->first, i};
        }
        // sort the entries within each group
        std::sort(value_index_warp.begin(), value_index_warp.end(),
                  [](const std::pair<float, size_t> &a, const std::pair<float, size_t> &b) {
                      return a.first > b.first;
                  });
        // the sum of the top two entries is used as the group score
        value_index_group[igroup] = {value_index_warp[0].first + value_index_warp[1].first, igroup};
    }

    // ------------------------------------------------------------------ //
    // among the group scores, keep only the top topk_group groups         //
    // ------------------------------------------------------------------ //
    std::sort(value_index_group.begin(), value_index_group.end(),
              [](const std::pair<float, size_t> &a, const std::pair<float, size_t> &b) {
                  return a.first > b.first;
              });

    std::vector<bool> group_mask;
    group_mask.resize(n_group, false);
    for (size_t i = 0; i < topk_group; ++i) {
        size_t index = value_index_group[i].second;
        group_mask[index] = true;
    }

    // ------------------------------------------------------------------ //
    // zero out the scores of groups whose mask is false                   //
    // ------------------------------------------------------------------ //
    for (size_t igroup = 0; igroup < n_group; ++igroup) {
        if (group_mask[igroup]) {
            continue;
        }
        auto it = value_index_arr.begin() + igroup * group_size;
        for (size_t i = 0; i < group_size; ++i) {
            (it++)->first = 0.0f;
        }
    }

    // ------------------------------------------------------------------ //
    // finally, take the global top-k                                      //
    // ------------------------------------------------------------------ //
    std::sort(value_index_arr.begin(), value_index_arr.end(),
              [](const std::pair<float, size_t> &a, const std::pair<float, size_t> &b) {
                  return a.first > b.first;
              });

    // ----------------------------------------------------------- //
    // collect the topk entries                                     //
    // ----------------------------------------------------------- //
    float exp_sum = 1e-9f;
    for (size_t i = 0; i < topk; ++i) {
        size_t index = value_index_arr[i].second;
        float exp_value = sigmoid_func(x_input[index]);
        values_input[i] = exp_value;
        indices_input[i] = static_cast<int>(index);
        exp_sum += exp_value;
    }

    // ----------------------------------------------------------- //
    // normalize                                                    //
    // ----------------------------------------------------------- //
    if (norm_topk_prob) {
        for (size_t i = 0; i < topk; ++i) {
            values_input[i] = routed_scaling_factor * values_input[i] / exp_sum;
        }
    }
}

template <typename T>
infiniStatus_t topkrouter_cpu_func(float *values,
                                   int *indices,
                                   const T *x,
                                   const float *correction_bias,
                                   const float routed_scaling_factor,
                                   const size_t topk,
                                   const size_t N,
                                   const size_t width,
                                   const size_t n_routed_experts = 256,
                                   const size_t n_group = 8,
                                   const size_t topk_group = 4,
                                   const bool norm_topk_prob = true) {
    /*
       O-----------> width (addresses are contiguous along this axis)
       |
       |
       N
    */
    for (size_t n = 0; n < N; ++n) {
        float *values_input = values + n * topk;
        int *indices_input = indices + n * topk;
        const T *x_input = x + n * width;

        std::vector<std::pair<float, size_t>> value_index_arr;
        value_index_arr.resize(width);
        for (size_t i = 0; i < width; ++i) {
            float temp;
            if constexpr (std::is_same<T, fp16_t>::value) {
                temp = _f16_to_f32(x_input[i]);
            } else if constexpr (std::is_same<T, bf16_t>::value) {
                temp = _bf16_to_f32(x_input[i]);
            } else {
                temp = x_input[i];
            }
            value_index_arr[i] = {temp, i};
        }

        topkrouter_cpu_one_token<T>(values_input, indices_input, x_input, value_index_arr,
                                    correction_bias, routed_scaling_factor, topk, width,
                                    n_routed_experts, n_group, topk_group, norm_topk_prob);
    }
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    float *values,
    int *indices,
    const void *x,
    const float *correction_bias,
    const float routed_scaling_factor,
    const size_t topk,
    void *stream) const {
    size_t N = _info.N;
    size_t width = _info.width;

    // the hyperparameters below come from the deepseek config.json
    const size_t n_routed_experts = 256;
    const size_t n_group = 8;
    const size_t topk_group = 4;
    const bool norm_topk_prob = true;

    if ((width != n_routed_experts) || (width % n_group != 0) || (256 != width)) {
        return INFINI_STATUS_BAD_PARAM;
    }

    if (_info.xtype == INFINI_DTYPE_F32) {
        topkrouter_cpu_func(values, indices, static_cast<const float *>(x), correction_bias,
                            routed_scaling_factor, topk, N, width,
                            n_routed_experts, n_group, topk_group, norm_topk_prob);
    } else if (_info.xtype == INFINI_DTYPE_F16) {
        topkrouter_cpu_func(values, indices, static_cast<const fp16_t *>(x), correction_bias,
                            routed_scaling_factor, topk, N, width,
                            n_routed_experts, n_group, topk_group, norm_topk_prob);
    } else if (_info.xtype == INFINI_DTYPE_BF16) {
        topkrouter_cpu_func(values, indices, static_cast<const bf16_t *>(x), correction_bias,
                            routed_scaling_factor, topk, N, width,
                            n_routed_experts, n_group, topk_group, norm_topk_prob);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::topkrouter::cpu
src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.h    (+2 -3)

-#ifndef __Topkrouter_CPU_H__
-#define __Topkrouter_CPU_H__
+#ifndef __TOPKTOUTER_CPU_H__
+#define __TOPKTOUTER_CPU_H__
 #include "../topkrouter.h"

 DESCRIPTOR(cpu)

 #endif
src/infiniop/ops/topkrouter/cuda/kernel.cuh    (+20 -22; comments translated from Chinese)

-#ifndef _Topkrouter_KERNEL_CUH__
-#define _Topkrouter_KERNEL_CUH__
+#ifndef _TOPKROUTER_KERNEL_CUH__
+#define _TOPKROUTER_KERNEL_CUH__
 #include <cfloat>
 #include <cub/block/block_load.cuh>
 #include <cub/block/block_radix_sort.cuh>
...
@@ -36,17 +36,14 @@ struct CustomLess {
     }
 };

-//
-// deepseek top-k
-//
 template <typename T, int BLOCK_THREADS = 256>
-__global__ void topkrouter_kernel(float *values_topk,            // output values, shape [N, topk]
-                                  int *indices_topk,             // output indices, shape [N, topk]
-                                  T *input,                      // input data [N, width]
-                                  float *d_correction_bias,      // input data [width]
-                                  float routed_scaling_factor,   //
-                                  const size_t N,                // total number of rows (token count)
-                                  const size_t width,            // elements per row
+__global__ void topkrouter_kernel(float *values_topk,                // output data, shape [N, topk]
+                                  int *indices_topk,                 // output indices, shape [N, topk]
+                                  const T *input,                    // input data [N, width]
+                                  const float *d_correction_bias,    // input data [width]
+                                  const float routed_scaling_factor,
+                                  const size_t N,
+                                  const size_t width,
                                   const size_t topk) {
...
@@ -99,7 +96,6 @@ __global__ void topkrouter_kernel(...)
     // ----------------------------------------------------------- //
     // sum of the top two values within each group                 //
     // ----------------------------------------------------------- //
     __syncthreads();
     if (0 == lane_id) {
         share_data_group[warp_id] = share_data[warp_id * warp_threads] + share_data[warp_id * warp_threads + 1];
...
@@ -116,8 +112,10 @@ __global__ void topkrouter_kernel(...)
         thread_indices[0] = lane_id;
     }
+    {
         __shared__ typename WarpMergeSortT::TempStorage temp_storage[1];
         WarpMergeSortT(temp_storage[0]).Sort(thread_values, thread_indices, CustomLess());
+    }
     if (lane_id < 4) {
         int indices = thread_indices[0];
         share_data_group_mask[indices] = 1.0f;
...
@@ -147,13 +145,13 @@ __global__ void topkrouter_kernel(...)
         int index = thread_indices[0];
         value = sigmoid_func(data_input[index]);
     }
     {
         typedef cub::WarpReduce<float, warp_threads> WarpReduce;
         __shared__ typename WarpReduce::TempStorage temp_storage;
         // partial reduction over the valid group entries
         float warp_sum = WarpReduce(temp_storage).Sum(value);
         if (0 == tid) {
-            share_sum = warp_sum + 1e-20;
+            share_sum = warp_sum + 1e-9f;
         }
     }
     __syncwarp();
...
src/infiniop/ops/topkrouter/info.h    (+2 -2)

-#ifndef __topkrouter_INFO_H__
-#define __topkrouter_INFO_H__
+#ifndef __TOPKROUTER_INFO_H__
+#define __TOPKROUTER_INFO_H__
 #include "../../../utils.h"
 #include "../../tensor.h"
...
src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu    (+13 -8)

...
@@ -40,9 +40,9 @@ infiniStatus_t Descriptor::create(
 namespace {

 template <int BLOCK_SIZE = 128>
-infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, void *d_input, float *d_correction_bias, float routed_scaling_factor,
-                                 size_t N, size_t width, size_t topk, infiniDtype_t xtype, cudaStream_t stream) {
+infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
+                                 const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype, cudaStream_t stream) {
     const int block_threads = BLOCK_SIZE;
     dim3 blocks(N);
     dim3 threads(block_threads);
...
@@ -63,9 +63,15 @@ infiniStatus_t launch_topkrouter(...)
 }; // namespace

 infiniStatus_t Descriptor::calculate(
     void *workspace,
     size_t workspace_size,
     float *values,
     int *indices,
-    void *x,
-    float *correction_bias,
-    float routed_scaling_factor,
-    size_t topk,
+    const void *x,
+    const float *correction_bias,
+    const float routed_scaling_factor,
+    const size_t topk,
     void *stream) const {

     if (workspace_size < _workspace_size) {
         return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
     }
...
@@ -76,13 +82,12 @@ infiniStatus_t Descriptor::calculate(
     // size_t n_routed_experts = 256;
     // size_t n_group = 8;
     // size_t topk_group = 4;
     auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);

     if (256 == width) {
         launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
     } else {
-        return INFINI_STATUS_INTERNAL_ERROR;
+        return INFINI_STATUS_BAD_PARAM;
     }

     return INFINI_STATUS_SUCCESS;
...