Commit 85762c1a authored by Xiaowei.zhang's avatar Xiaowei.zhang
Browse files

Init the main branch for aiter

parent ae0b3521
Pipeline #3505 canceled with stages
M,N,K,bias,dtype,outdtype,scaleAB
"""
* Copyright (C) 2024-2025, The vLLM team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
"""
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed
from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(
input_: torch.Tensor, open_fp8_quant: bool = False
) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_, open_fp8_quant)
def tensor_model_parallel_fused_allreduce_rmsnorm(
input_: torch.Tensor, residual_inp_: torch.Tensor, weight_: torch.Tensor, eps: float
) -> tuple[torch.Tensor, torch.Tensor]:
return get_tp_group().fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)
def tensor_model_parallel_custom_all_gather(input_: torch.Tensor) -> torch.Tensor:
return get_tp_group().custom_all_gather(input_)
def tensor_model_parallel_all_gather(
input_: torch.Tensor, use_custom: bool = False, dim: int = -1
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, use_custom, dim)
def tensor_model_parallel_gather(
input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
if not torch.distributed.is_initialized():
return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment