"...text-generation-inference.git" did not exist on "379c5c4da2494a15fb82b3b1a39fa454cb73df44"
Unverified commit 4ff02039, authored by Woosuk Kwon, committed by GitHub

Minor fixes for Mixtral (#2015)

parent b5f882cc
In the supported models documentation (reStructuredText table of architectures):

@@ -50,6 +50,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`MistralForCausalLM`
     - Mistral, Mistral-Instruct
     - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
+  * - :code:`MixtralForCausalLM`
+    - Mixtral-8x7B, Mixtral-8x7B-Instruct
+    - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
   * - :code:`MPTForCausalLM`
     - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
     - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
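With this table entry in place, Mixtral checkpoints load through vLLM's standard offline entry point. A minimal sketch using vLLM's LLM/SamplingParams API; the tensor_parallel_size value is illustrative and assumes a multi-GPU host, since Mixtral-8x7B is far too large for most single GPUs:

from vllm import LLM, SamplingParams

# Shard the model (including the MoE expert weights) across two GPUs.
llm = LLM(model="mistralai/Mixtral-8x7B-v0.1", tensor_parallel_size=2)

params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["The Mixtral architecture uses"], params)
print(outputs[0].outputs[0].text)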
In the Mixtral model implementation:

@@ -21,7 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 import numpy as np
@@ -453,10 +453,6 @@ class MixtralForCausalLM(nn.Module):
         assert linear_method is None
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.tok_embeddings: Union[nn.Embedding, None] = None
-        self.layers: nn.ModuleList = None
-        self.output: Union[nn.Linear, None] = None
-        self.sampler: Union[Sampler, None] = None
         self.tok_embeddings = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
@@ -492,6 +488,7 @@ class MixtralForCausalLM(nn.Module):
                 input_metadata,
                 cache_event,
             )
+        hidden_states = self.norm(hidden_states)
         return hidden_states

     def sample(
@@ -499,7 +496,6 @@ class MixtralForCausalLM(nn.Module):
         hidden_states: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        hidden_states = self.norm(hidden_states)
         next_tokens = self.sampler(self.output.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
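The substantive fix is in the last two hunks: the final RMSNorm moves out of sample() and into forward(), so forward() returns fully normalized hidden states and sample() becomes a pure sampling step. The first two hunks simply drop placeholder attribute annotations (and the now-unused Union import) that were immediately overwritten later in __init__. A toy sketch of the resulting forward/sample split, with hypothetical names and nn.LayerNorm standing in for Mixtral's RMSNorm:

import torch
import torch.nn as nn

class ToyCausalLM(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.block = nn.Linear(hidden_size, hidden_size)  # stand-in for the decoder layers
        self.norm = nn.LayerNorm(hidden_size)             # final norm (RMSNorm in Mixtral)
        self.output = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden_states = self.block(self.embed(input_ids))
        # After the fix: normalization happens here, inside the model's forward pass.
        return self.norm(hidden_states)

    def sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pure sampling: project already-normalized states to logits and pick tokens.
        logits = hidden_states @ self.output.weight.t()
        return torch.argmax(logits, dim=-1)

model = ToyCausalLM(hidden_size=16, vocab_size=100)
next_tokens = model.sample(model(torch.tensor([[1, 2, 3]])))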