Commit 99a0c39e authored by xingjinliang

Sync latest code

parent 50fe58fa
Pipeline #2152 passed with stage
File mode changed from 100755 to 100644
@@ -13,8 +13,8 @@ from packaging.version import Version as PkgVersion
 from torch import Tensor
 from torch.nn.parameter import Parameter
 
-from megatron.core import ModelParallelConfig
 from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
+from megatron.core.model_parallel_config import ModelParallelConfig
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.parallel_state import (
     get_context_parallel_global_ranks,
@@ -654,6 +654,23 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
         else:
             kv_channels = self.config.kv_channels
 
+        self.kept_packed_seq_params = set(
+            field.name for field in dataclasses.fields(PackedSeqParams)
+        )
+
+        if get_te_version() < PkgVersion("1.3.0"):
+            # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
+            # copies (#555)
+            # These two arguments did not exist prior to 1.3.0
+            self.kept_packed_seq_params.discard("max_seqlen_q")
+            self.kept_packed_seq_params.discard("max_seqlen_kv")
+
+        if get_te_version() < PkgVersion("1.10.0"):
+            # TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
+            # in each individual sequence in THD format dataset
+            # These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
+            self.kept_packed_seq_params.discard("cu_seqlens_q_padded")
+            self.kept_packed_seq_params.discard("cu_seqlens_kv_padded")
+
         super().__init__(
             num_attention_heads=self.config.num_attention_heads,
             kv_channels=kv_channels,
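The new init-time logic above computes the set of PackedSeqParams field names the installed TE release understands once, instead of re-pruning kwargs on every forward pass. A minimal self-contained sketch of the same pattern, using a stand-in dataclass and a hypothetical version tuple in place of Megatron's get_te_version():

import dataclasses

@dataclasses.dataclass
class DemoPackedSeqParams:  # stand-in for megatron.core.packed_seq_params.PackedSeqParams
    qkv_format: str = None
    cu_seqlens_q: object = None
    cu_seqlens_kv: object = None
    max_seqlen_q: object = None
    max_seqlen_kv: object = None
    cu_seqlens_q_padded: object = None
    cu_seqlens_kv_padded: object = None

# Start from every declared field, then drop the ones the installed
# TE version cannot accept; this runs once, at construction time.
kept = {f.name for f in dataclasses.fields(DemoPackedSeqParams)}
te_version = (1, 2, 0)  # hypothetical installed version
if te_version < (1, 3, 0):
    kept.discard("max_seqlen_q")
    kept.discard("max_seqlen_kv")
if te_version < (1, 10, 0):
    kept.discard("cu_seqlens_q_padded")
    kept.discard("cu_seqlens_kv_padded")

# Per forward call, only the surviving fields are read.
params = DemoPackedSeqParams(qkv_format="thd")
kwargs = {key: getattr(params, key) for key in kept}
assert "max_seqlen_q" not in kwargs and "qkv_format" in kwargs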
@@ -683,7 +700,9 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
     ):
         """Forward."""
         packed_seq_kwargs = (
-            dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
+            {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params}
+            if packed_seq_params is not None
+            else {}
         )
         # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
         # after init
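A likely reason the forward path swaps dataclasses.asdict for a getattr comprehension, beyond reusing the precomputed field set: asdict recursively deep-copies field values, so tensor-valued fields get duplicated on every call, while getattr hands back the original objects. A small illustration of that difference (this motivation is inferred, not stated in the diff):

import dataclasses
import torch

@dataclasses.dataclass
class Params:
    cu_seqlens_q: torch.Tensor = None

p = Params(cu_seqlens_q=torch.tensor([0, 4, 9], dtype=torch.int32))

# dataclasses.asdict recurses via copy.deepcopy, so the tensor is cloned.
copied = dataclasses.asdict(p)
assert copied["cu_seqlens_q"] is not p.cu_seqlens_q

# The comprehension returns the very same tensor objects, no copies.
aliased = {f.name: getattr(p, f.name) for f in dataclasses.fields(p)}
assert aliased["cu_seqlens_q"] is p.cu_seqlens_q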
@@ -692,24 +711,10 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
         qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
 
-        if get_te_version() < PkgVersion("1.3.0"):
-            # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
-            # copies (#555)
-            # These two arguments did not exist prior to 1.3.0
-            packed_seq_kwargs.pop("max_seqlen_q", None)
-            packed_seq_kwargs.pop("max_seqlen_kv", None)
-
-        if get_te_version() < PkgVersion("1.10.0"):
-            # TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
-            # in each individual sequence in THD format dataset
-            # These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
-            packed_seq_kwargs.pop("cu_seqlens_q_padded", None)
-            packed_seq_kwargs.pop("cu_seqlens_kv_padded", None)
-
         # WAR for peak memory usage.
         # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
         if self.config.apply_rope_fusion and qkv_format == 'bshd':
-            query, key, value = [x.contiguous().transpose(0, 1) for x in (query, key, value)]
+            query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
             # In PyTorch, the following two tensors are in fact the same:
             #     Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
             #     Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
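The reordering from contiguous().transpose(0, 1) to transpose(0, 1).contiguous() leans on the stride fact cited in the comment: PyTorch ignores the stride of size-1 dimensions when testing contiguity, so with batch size 1 the transposed sbhd-to-bshd view is already contiguous and the trailing .contiguous() becomes a no-op. A small check of that behavior (the peak-memory rationale itself lives in the linked MR):

import torch

s, b, h, d = 8, 1, 2, 4                  # batch of one, as in the comment
x = torch.randn(s, b, h, d)              # sbhd layout, contiguous

y = x.transpose(0, 1)                    # bshd view, shape (1, S, H, D)
assert y.is_contiguous()                 # size-1 dim 0 accepts any stride
z = y.contiguous()                       # no-op: same storage comes back
assert z.data_ptr() == x.data_ptr()      # no extra allocation for b == 1

# With b > 1 the transposed view is genuinely non-contiguous, and the new
# ordering pays for the bshd copy here, up front, rather than inside TE.
xb = torch.randn(8, 3, 2, 4)
assert not xb.transpose(0, 1).is_contiguous()
assert xb.transpose(0, 1).contiguous().data_ptr() != xb.data_ptr()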
@@ -1229,8 +1234,14 @@ try:
 
     from transformer_engine.pytorch.attention import FusedRoPEFunc
 
-    def fused_apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    def fused_apply_rotary_pos_emb(
+        t: torch.Tensor, freqs: torch.Tensor, transpose_output_memory: bool = False
+    ) -> torch.Tensor:
         """Apply rotary positional embedding to input tensor T in `sbhd` format."""
+        if transpose_output_memory:
+            warnings.warn(
+                "transpose_output_memory is not supported by TE's fused RoPE and will be ignored."
+            )
         return FusedRoPEFunc.apply(t, freqs, "sbhd")
 
     def fused_apply_rotary_pos_emb_thd(
......
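Caller-side view of the widened wrapper signature: the extra flag keeps call sites written against the unfused RoPE path working, while warning that TE's fused kernel ignores it. A hypothetical invocation (shapes and dtypes assumed; requires Transformer Engine and a CUDA device):

import warnings
import torch

s, b, h, d = 16, 2, 8, 64
t = torch.randn(s, b, h, d, device="cuda", dtype=torch.bfloat16)
freqs = torch.randn(s, 1, 1, d, device="cuda", dtype=torch.float32)

out = fused_apply_rotary_pos_emb(t, freqs)   # fast path, unchanged

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
assert "will be ignored" in str(caught[-1].message)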
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-
-from dataclasses import dataclass
-
-
-@dataclass
-class CommonInferenceParams:
-    """Inference parameters sent along with the prompts.
-
-    For an explanation of these parameters refer to this blog post:
-    https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910
-    """
-
-    temperature: float = 1.0
-    top_k: int = 0
-    top_p: float = 0.0
-    return_log_probs: bool = False
-    num_tokens_to_generate: int = 30
-
-    def add_attributes(self, attribute_value_pair: dict):
-        """Utility to add more attributes to inference params.
-
-        Use this method to pass in a custom dictionary to add more inference
-        parameter attributes to the instance you created. Use as follows:
-
-            c = CommonInferenceParams()
-            c.add_attributes({'min_length': 4, 'eod_id': 153})
-
-        Args:
-            attribute_value_pair (dict): A dictionary containing attributes as the
-                key names and their values as the values.
-        """
-        for key, value in attribute_value_pair.items():
-            setattr(self, key, value)
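Typical use of the dataclass and its add_attributes escape hatch, following the docstring example:

params = CommonInferenceParams(temperature=0.7, top_p=0.9, num_tokens_to_generate=64)
# setattr-based extension: fields not declared on the dataclass can still
# be attached to this specific instance.
params.add_attributes({'min_length': 4, 'eod_id': 153})
assert params.min_length == 4 and params.eod_id == 153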
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+from megatron.core.inference.sampling_params import (  # noqa: F401 # pylint: disable=unused-import
+    SamplingParams as CommonInferenceParams,
+)
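The replacement module keeps the old import path alive as a pure alias, so existing imports and isinstance checks against either name keep working:

from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.sampling_params import SamplingParams

assert CommonInferenceParams is SamplingParams   # one class, two names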
@@ -3,12 +3,12 @@ from typing import Dict, List
 
 import torch
 
-from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.engines.abstract_engine import AbstractEngine
 from megatron.core.inference.inference_request import InferenceRequest
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.scheduler import Scheduler
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
-    SimpleTextGenerationController,
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+    TextGenerationController,
 )
@@ -19,7 +19,7 @@ class MCoreEngine(AbstractEngine):
     Supports any model that is callable (accepts the inputs and returns a tensor)
 
     Args:
-        text_generation_controller (SimpleTextGenerationController): A text generation
+        text_generation_controller (TextGenerationController): A text generation
             controller that will be used to define how to preprocess prompts, generate
             outputs and detokenize the output tokens.
         max_batch_size : The maximum number of requests to process at once
@@ -29,7 +29,7 @@ class MCoreEngine(AbstractEngine):
 
     def __init__(
         self,
-        text_generation_controller: SimpleTextGenerationController,
+        text_generation_controller: TextGenerationController,
         max_batch_size,
         random_seed: int = None,
     ):
@@ -42,7 +42,8 @@ class MCoreEngine(AbstractEngine):
         prompts: List[str],
         add_BOS: bool = False,
         encoder_prompts: List[str] = None,
-        common_inference_params: CommonInferenceParams = None,
+        common_inference_params: SamplingParams = None,
+        sampling_params: SamplingParams = None,
     ) -> dict:
         """The megatron core inference backend generate function
@@ -54,13 +55,19 @@ class MCoreEngine(AbstractEngine):
             prompts (List[str]): All the prompts as a list of strings
             add_BOS (bool): Whether to add BOS token to beginning of prompts
             encoder_prompts (List[dict]): All the encoder prompts as a list of strings
-            common_inference_params (CommonInferenceParams): The inference parameters
+            common_inference_params: Deprecated. Only used for backward compatibility with
+                MCore <= 0.9.0. Use `sampling_params` going forward.
+            sampling_params (SamplingParams): The request-level sampling parameters
 
         Returns:
             List[InferenceRequest]: The output is list of inference requests containing the
             generated tokens, texts and log probs if required
         """
         # TODO: MCore - get RNG state tracker
+        if common_inference_params:
+            sampling_params = common_inference_params
+
         if self.random_seed:
             torch.random.manual_seed(self.random_seed)
@@ -73,7 +80,7 @@ class MCoreEngine(AbstractEngine):
                 prompt=prompt,
                 prompt_tokens=prompt_tokens,
                 encoder_prompt=encoder_prompt,
-                inference_parameters=common_inference_params,
+                inference_parameters=sampling_params,
             )
 
         self.run_engine()
......
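How the deprecation shim plays out at a call site: new code passes sampling_params, while pre-0.10 code that still passes common_inference_params is forwarded to the same path. A hypothetical sketch (controller and model setup omitted):

from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.sampling_params import SamplingParams

engine = MCoreEngine(
    text_generation_controller=controller,  # a TextGenerationController built elsewhere
    max_batch_size=8,
    random_seed=1234,
)

# New-style call.
results = engine.generate(
    prompts=["Hello, world"],
    sampling_params=SamplingParams(temperature=0.7, num_tokens_to_generate=64),
)

# Legacy call: generate() reassigns common_inference_params to
# sampling_params internally, so behavior is identical.
results = engine.generate(
    prompts=["Hello, world"],
    common_inference_params=SamplingParams(temperature=0.7),
)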