"vscode:/vscode.git/clone" did not exist on "a046f86397c06af306d964ff40d0670bdc9c00a2"
Unverified Commit e042d7e6 authored by tc-mb's avatar tc-mb Committed by GitHub
Browse files

Add flagos in MiniCPM-o (#34126)


Signed-off-by: default avatartc-mb <caitianchi@modelbest.cn>
Signed-off-by: default avatarVincent-Xiao <vincent.xiao.me@gmail.com>
Co-authored-by: default avatarVincent-Xiao <vincent.xiao.me@gmail.com>
parent ae4e2806
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
import os
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -75,6 +76,47 @@ from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix ...@@ -75,6 +76,47 @@ from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
if os.getenv("USE_FLAGOS") == "1":
import flag_gems
FLAG_GEMS_CONFIG = [
"sort",
"sort_stable",
"layer_norm",
"clamp_",
"cos",
"embedding",
"exp",
"exponential_",
"full",
"gather",
"gelu",
"index",
"le",
"lt",
"lt_scalar",
"masked_fill_",
"max",
"ones",
"pow_scalar",
"prod_dim",
"rand_like",
"reciprocal",
"repeat",
"scatter",
"scatter_",
"sin",
"sub",
"true_divide",
"true_divide_",
"uniform_",
"where_scalar_self",
"where_self_out",
"zeros",
"zeros_like",
]
flag_gems.only_enable(record=False, include=FLAG_GEMS_CONFIG)
class MiniCPMOAudioFeatureInputs(TensorSchema): class MiniCPMOAudioFeatureInputs(TensorSchema):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment