Unverified commit ad6f7805, authored by Yongzao, committed by GitHub
Browse files

[torch.compile] expanding support and fix allgather compilation (#9637)


Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
parent 295a061f
......@@ -392,8 +392,12 @@ class GroupCoordinator:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here;
# stack-style all-gather has compatibility issues with
# torch.compile. See https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size, ) + input_size[1:]
# Allocate output tensor.
-output_tensor = torch.empty((world_size, ) + input_size,
+output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
......@@ -401,6 +405,7 @@ class GroupCoordinator:
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
......
......@@ -25,6 +25,7 @@ from torch import nn
from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
......@@ -187,6 +188,7 @@ class GPTBigCodeBlock(nn.Module):
return hidden_states
@support_torch_compile
class GPTBigCodeModel(nn.Module):
def __init__(
......
......@@ -23,6 +23,7 @@ from torch import nn
from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
......@@ -174,6 +175,7 @@ class GPTJBlock(nn.Module):
return hidden_states
@support_torch_compile
class GPTJModel(nn.Module):
def __init__(
......
......@@ -23,6 +23,7 @@ from torch import nn
from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
......@@ -187,6 +188,7 @@ class GPTNeoXLayer(nn.Module):
return hidden_states
@support_torch_compile
class GPTNeoXModel(nn.Module):
def __init__(
......
......@@ -28,6 +28,7 @@ from torch import nn
from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
......@@ -254,6 +255,7 @@ class GraniteDecoderLayer(nn.Module):
return hidden_states
@support_torch_compile
class GraniteModel(nn.Module):
def __init__(
......
......@@ -7,6 +7,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -230,6 +231,7 @@ class InternLMDecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile
class InternLM2Model(nn.Module):
def __init__(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment