vllm_v0.7.2.patch 40.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
diff --git a/vllm/config.py b/vllm/config.py
index 9ba49757..7e871521 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2629,7 +2629,7 @@ class KVTransferConfig(BaseModel):
     kv_buffer_size: float = 1e9
 
     # Whether this vLLM instance produces, consumes KV cache, or both. Choices
-    # are 'kv_producer', 'kv_consumer', and 'both'.
+    # are 'kv_producer', 'kv_consumer', and 'kv_both'.
     kv_role: Optional[str] = None
 
     # The rank of this vLLM instance in the KV cache transfer. Typical value:
@@ -2647,6 +2647,14 @@ class KVTransferConfig(BaseModel):
     # The KV connector port, used to build distributed connection
     kv_port: int = 14579
 
+
+    # This does not need to be set by the user. It is set by the connector.
+    kv_producers_parallel_size: Optional[int] = None
+    kv_producers_tensor_parallel_size: Optional[int] = None
+    kv_producers_pipeline_parallel_size: Optional[int] = None
+    kv_consumers_tensor_parallel_size: Optional[int] = None
+    kv_consumers_pipeline_parallel_size: Optional[int] = None
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -2685,6 +2693,7 @@ class KVTransferConfig(BaseModel):
                              "is set, supported roles are `kv_producer`, "
                              "`kv_consumer`, and `kv_both`")
 
+
     @property
     def is_kv_transfer_instance(self) -> bool:
         return self.kv_connector is not None and \
@@ -2706,6 +2715,18 @@ class KVTransferConfig(BaseModel):
         return self.kv_connector is not None and \
             self.kv_role in ["kv_consumer", "kv_both"]
 
+    @property
+    def tensor_parallel_multiplier(self) -> int:
+        return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size
+
+    @property
+    def kv_consumers_parallel_size(self) -> int:
+        return self.kv_parallel_size - self.kv_producers_parallel_size
+
+    @property
+    def kv_world_size(self) -> int:
+        return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier
+
 
 class CompilationLevel:
     # constants for the levels of the compilation process
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..b768e03c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -27,13 +27,13 @@ class KVConnectorFactory:
 
     @classmethod
     def create_connector(cls, rank: int, local_rank: int,
-                         config: "VllmConfig") -> KVConnectorBase:
+                         config: "VllmConfig", world_group) -> KVConnectorBase:
         connector_name = config.kv_transfer_config.kv_connector
         if connector_name not in cls._registry:
             raise ValueError(f"Unsupported connector type: {connector_name}")
 
         connector_cls = cls._registry[connector_name]()
-        return connector_cls(rank, local_rank, config)
+        return connector_cls(rank, local_rank, config, world_group)
 
 
 # Register various connectors here.
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
77
index 2033e976..e33919c1 100644
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -8,13 +8,15 @@ MooncakePipe.
 
 But the logic can be extended to support other pipe and lookup buffer.
 """
+import re
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
     SimpleBuffer)
 from vllm.logger import init_logger
97
@@ -33,6 +35,7 @@ class SimpleConnector(KVConnectorBase):
98
99
100
101
102
103
104
         rank: int,
         local_rank: int,
         config: VllmConfig,
+        world_group,
     ):
 
         self.config = config.kv_transfer_config
105
@@ -71,20 +74,31 @@ class SimpleConnector(KVConnectorBase):
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
         self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
         self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
 
+        self._broadcast_and_enhance_kv_config(rank, config, world_group)
+
+        self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
+        self.tp_size = config.parallel_config.tensor_parallel_size
+
         # 2 pipes for every rank in the world
-        port_offset_base = 2 * rank
+        if self.config.is_kv_producer:
+            port_offset_base = 2 * rank + 1
+        else:
+            port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1
 
+        self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
         # In disaggregated prefill, the prefill vLLM only uses send pipe
         # and the decode vLLM only uses recv pipe
         if self.config.is_kv_producer:
 
             if self.config.kv_connector == "PyNcclConnector":
                 self.producer_data_pipe = PyNcclPipe(
+                    kv_group_rank=self.kv_group_rank,
                     local_rank=local_rank,
                     config=self.config,
                     port_offset=port_offset_base,
                 )
                 self.producer_signal_pipe = PyNcclPipe(
+                    kv_group_rank=self.kv_group_rank,
                     local_rank=local_rank,
                     config=self.config,
                     port_offset=port_offset_base + 1,
138
@@ -108,11 +122,13 @@ class SimpleConnector(KVConnectorBase):
139
140
141
142
143
144
145
146
147
148
149
150
151
             # its recv pipe to the send pipe of KV producder
             if self.config.kv_connector == "PyNcclConnector":
                 self.consumer_data_pipe = PyNcclPipe(
+                    kv_group_rank=self.kv_group_rank,
                     local_rank=local_rank,
                     config=self.config,
                     port_offset=port_offset_base,
                 )
                 self.consumer_signal_pipe = PyNcclPipe(
+                    kv_group_rank=self.kv_group_rank,
                     local_rank=local_rank,
                     config=self.config,
                     port_offset=port_offset_base + 1,
152
@@ -131,21 +147,25 @@ class SimpleConnector(KVConnectorBase):
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
                 self.config.kv_buffer_size,
             )
 
-    def select(self, input_tokens: Optional[torch.Tensor],
+    def select(self, source_rank: int, input_tokens: Optional[torch.Tensor],
                roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
 
+        logger.info("Selecting KV caches and hidden states for source rank %d", source_rank)
+
         assert self.consumer_buffer is not None, "Please initialize the "\
             "consumer buffer before calling select."
-        return self.consumer_buffer.drop_select(input_tokens, roi)
+        return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi)
 
-    def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+    def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
                key: torch.Tensor, value: torch.Tensor,
                hidden: torch.Tensor) -> None:
 
+        logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank)
+
         assert self.producer_buffer is not None, "Please initialize the "\
             "producer buffer before calling insert."
 
-        self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
+        self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden)
 
     def send_kv_caches_and_hidden_states(
         self,
182
@@ -161,12 +181,20 @@ class SimpleConnector(KVConnectorBase):
183
184
185
186
187
188
         slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
+        request_ids = list(model_input.request_ids_to_seq_ids.keys())
 
         model_config = model_executable.model.config
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
-        num_heads = int(model_config.num_key_value_heads / self.tp_size)
-        hidden_size = model_config.hidden_size
-        num_attention_heads = model_config.num_attention_heads
-        head_size = int(hidden_size / num_attention_heads)
+        is_deepseek = "deepseek" in model_config.architectures[0].lower()
+        if not is_deepseek:
+            num_heads = int(model_config.num_key_value_heads / self.tp_size)
+            hidden_size = model_config.hidden_size
+            num_attention_heads = model_config.num_attention_heads
+            head_size = int(hidden_size / num_attention_heads)
+        else:
+            num_heads = int(model_config.num_key_value_heads / self.tp_size)
+            hidden_size = model_config.hidden_size
+            num_attention_heads = model_config.num_attention_heads
+            head_size = int(4.5 * hidden_size / num_attention_heads)
 
         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
@@ -175,27 +203,40 @@ class SimpleConnector(KVConnectorBase):
208
209
210
211
212
213
214
215
216
217
             start_pos = sum(seq_lens[:idx])
             end_pos = start_pos + slen
             current_tokens = input_tokens_tensor[start_pos:end_pos]
+            current_request_id = request_ids[idx]
+            _, decode_kv_rank = self.parse_request_id(current_request_id)
+            starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config)
+
+            for target_rank in range(self.config.tensor_parallel_multiplier):
 
-            keys, values = [], []
218
+                keys, values = [], []
219
220
221
 
-            for layer_id in range(start_layer, end_layer):
-                kv_cache = kv_caches[layer_id - start_layer]
222
223
+                for layer_id in range(start_layer, end_layer):
+                    kv_cache = kv_caches[layer_id - start_layer]
224
225
226
227
228
229
230
231
232
233
234
235
 
-                key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
-                value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                    current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
 
-                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+                    num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
+                    head_start = target_rank * num_heads_per_rank
+                    head_end = head_start + num_heads_per_rank
 
-                keys.append(key_cache[current_slot_mapping].unsqueeze(0))
-                values.append(value_cache[current_slot_mapping].unsqueeze(0))
236
237
238
239
240
241
242
243
244
+                    if not is_deepseek:
+                        key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+                        value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+                        keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+                        values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+                    else:
+                        key_cache = kv_cache
+                        keys.append(key_cache[current_slot_mapping].unsqueeze(0))
+                        values.append(torch.empty(0))
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
 
-            keys = torch.cat(keys, dim=0)
-            values = torch.cat(values, dim=0)
+                keys = torch.cat(keys, dim=0)
+                values = torch.cat(values, dim=0)
 
-            self.insert(current_tokens,
-                        torch.ones_like(current_tokens,
-                                        dtype=bool), keys, values,
-                        hidden_or_intermediate_states[start_pos:end_pos])
+                self.insert(starting_kv_group_rank, target_rank, current_tokens,
+                            torch.ones_like(current_tokens,
+                                            dtype=bool), keys, values,
+                            hidden_or_intermediate_states[start_pos:end_pos])
 
         logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
 
262
@@ -215,6 +256,7 @@ class SimpleConnector(KVConnectorBase):
263
264
265
266
267
268
269
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+        request_ids = list(model_input.request_ids_to_seq_ids.keys())
 
         hidden_or_intermediate_states_for_one_req = []
 
270
271
272
273
274
275
276
277
278
279
280
@@ -222,6 +264,9 @@ class SimpleConnector(KVConnectorBase):
         num_computed_tokens_list = []
         start_pos_list = []
 
+        model_config = model_executable.model.config
+        is_deepseek = "deepseek" in model_config.architectures[0].lower()
+
         # enumerate different requests
         # FIXME(Kuntai): This impl assumes that all requests are prefill.
         for idx, slen in enumerate(seq_lens):
@@ -229,13 +274,15 @@ class SimpleConnector(KVConnectorBase):
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
             start_pos = sum(seq_lens[:idx])
             end_pos = start_pos + slen
             current_tokens = input_tokens_tensor[start_pos:end_pos]
+            current_request_id = request_ids[idx]
+            prefill_rank, _ = self.parse_request_id(current_request_id)
             num_tokens = slen
 
             # collecting data for rebuilding the input
             input_tokens_list.append(current_tokens)
             start_pos_list.append(start_pos)
 
-            ret = self.select(current_tokens,
+            ret = self.select(prefill_rank, current_tokens,
                               torch.ones_like(current_tokens, dtype=bool))
             if ret[0] is None:
                 # didn't find any match.
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
@@ -267,19 +314,25 @@ class SimpleConnector(KVConnectorBase):
                 kv_cache = kv_caches[i - model_executable.model.start_layer]
                 layer = model_executable.model.layers[i]
 
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                ops.reshape_and_cache_flash(
-                    keys[i - model_executable.model.start_layer].to(
-                        key_cache.device),
-                    values[i - model_executable.model.start_layer].to(
-                        value_cache.device),
-                    key_cache,
-                    value_cache,
-                    slot_mapping[start_pos:end_pos],
-                    layer.self_attn.attn.kv_cache_dtype,
-                    layer.self_attn.attn._k_scale,
-                    layer.self_attn.attn._v_scale,
-                )
+                if not is_deepseek:
+                    key_cache, value_cache = kv_cache[0], kv_cache[1]
+                    ops.reshape_and_cache_flash(
+                        keys[i - model_executable.model.start_layer].to(
+                            key_cache.device),
+                        values[i - model_executable.model.start_layer].to(
+                            value_cache.device),
+                        key_cache,
+                        value_cache,
+                        slot_mapping[start_pos:end_pos],
+                        layer.self_attn.attn.kv_cache_dtype,
+                        layer.self_attn.attn._k_scale,
+                        layer.self_attn.attn._v_scale,
+                    )
+                else:
+                    key_cache = kv_cache
+                    copy_from =keys[i - model_executable.model.start_layer].to(
+                            key_cache.device)
+                    kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
 
             hidden_or_intermediate_states_for_one_req.append(hidden)
 
@@ -312,3 +365,77 @@ class SimpleConnector(KVConnectorBase):
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
             # MooncakePipe reuses data_pipe for signal_pipe, so we only have to
             # close the data_pipe.
             pass
+
+    @staticmethod
+    def parse_request_id(request_id):
+        # Regular expression to match the ranks
+        pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)"
+        
+        # Use re.search to find the pattern in the request_id
+        match = re.search(pattern, request_id)
+        
+        if match:
+            # Extract the ranks
+            prefill_rank = int(match.group(1))
+            decode_rank = int(match.group(2))
+            
+            return prefill_rank, decode_rank
+        else:
+            return None, None
+
+    
+
+    def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+        if kv_rank < config.kv_producers_parallel_size:
+            return kv_rank
+        
+        kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+        return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
+
+    def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
+        if rank == 0:
+            if self.config.kv_connector == "PyNcclConnector":
+                config_group = StatelessProcessGroup.create(
+                    host=self.config.kv_ip,
+                    port=self.config.kv_port,
+                    rank=self.config.kv_rank,
+                    world_size=self.config.kv_parallel_size,
+                )
+                parallel_configs = config_group.all_gather_obj({
+                    "kv_role": self.config.kv_role,
+                    "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+                    "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+                })
+                logger.debug("parallel_configs: %s", parallel_configs)
+                kv_config_enhanced = {
+                    "kv_producers_tensor_parallel_size": None,
+                    "kv_consumers_tensor_parallel_size": None,
+                    "kv_producers_pipeline_parallel_size": None,
+                    "kv_consumers_pipeline_parallel_size": None,
+                    "kv_producers_parallel_size": 0,
+                }
+                for parallel_config in parallel_configs:
+                    kv_role = parallel_config["kv_role"]
+                    assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+                    
+                    if kv_role == "kv_producer":
+                        kv_config_enhanced["kv_producers_parallel_size"] += 1
+                    if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+                        kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+                        kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+                    else:
+                        assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+                        assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+                world_group.broadcast_object(kv_config_enhanced)
+
+            else:
+                raise NotImplementedError("MooncakeConnector is not supported in Triton Distributed vllm patch")
+        else:
+            kv_config_enhanced = world_group.broadcast_object()
+        logger.info("kv_config_enhanced: %s", kv_config_enhanced)
+
+        self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
+        self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
+        self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
+        self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+        self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
index 5e1b6235..b4506877 100644
--- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
+++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
@@ -12,7 +12,8 @@
 import threading
 import time
 from collections import deque
-from typing import Deque, List, Optional, Union
+from concurrent.futures import ThreadPoolExecutor
+from typing import Deque, List, Optional, Union, Dict
 
 import torch
 
@@ -46,7 +47,7 @@ class SimpleBuffer(KVLookupBufferBase):
         self.buffer_lock = threading.Lock()
         self.signal_pipe = signal_pipe
         self.data_pipe = data_pipe
-        self.request_handling_thread: Optional[threading.Thread] = None
+        self.request_handling_thread: Optional[ThreadPoolExecutor] = None
 
         self.normal_signal = torch.tensor([0], device="cpu")
         self.end_signal = None
@@ -57,10 +58,16 @@ class SimpleBuffer(KVLookupBufferBase):
         # tokens_roi_sender: tokens and roi of the producer (in the buffer)
         # tokens_roi_recver: tokens and roi of the consumer (query)
 
-        tokens_sender = tokens_roi_sender[0]
-        tokens_recver = tokens_roi_recver[0]
-        roi_sender = tokens_roi_sender[1]
-        roi_recver = tokens_roi_recver[1]
+        target_rank_sender = tokens_roi_sender[0]
+        target_rank_recver = tokens_roi_recver[0]
+
+        if target_rank_sender.item() != target_rank_recver.item():
+            return 0
+        
+        tokens_sender = tokens_roi_sender[1]
+        tokens_recver = tokens_roi_recver[1]
+        roi_sender = tokens_roi_sender[2]
+        roi_recver = tokens_roi_recver[2]
 
         if tokens_recver is None:
             # consumer sends an empty request
@@ -80,14 +87,14 @@ class SimpleBuffer(KVLookupBufferBase):
 
         return 0
 
-    def _send_tensor_and_dec_size(self,
-                                  tensor: Optional[torch.Tensor]) -> None:
+    def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor],
+                                  target_rank: int) -> None:
 
         assert tensor is not None, "Use self.data_pipe.send(None) instead"
         self.buffer_size -= tensor.element_size() * tensor.numel()
         if tensor.dtype == torch.bool:
             tensor = tensor.float()
-        self.data_pipe.send_tensor(tensor)
+        self.data_pipe.send_tensor(tensor, target_rank)
 
     def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
 
@@ -100,7 +107,7 @@ class SimpleBuffer(KVLookupBufferBase):
 
         raise AssertionError(f"Unknown data type {type(data)}")
 
-    def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+    def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
                        key: torch.Tensor, value: torch.Tensor,
                        hidden: torch.Tensor):
 
@@ -115,7 +122,7 @@ class SimpleBuffer(KVLookupBufferBase):
         if isinstance(hidden, torch.Tensor):
             hidden = hidden.clone()
 
-        buffer_item = [input_tokens, roi, key, value, hidden]
+        buffer_item = [torch.tensor(target_rank), input_tokens, roi, key, value, hidden]
 
         with self.buffer_lock:
             for data in buffer_item:
@@ -125,53 +132,54 @@ class SimpleBuffer(KVLookupBufferBase):
     def _is_end_signal(self, signal):
         return signal is None
 
-    def drop_select_handler(self):
+    def drop_select_handler(self, rank: int):
 
         try:
 
-            while True:
-                signal = self.signal_pipe.recv_tensor()
-                if self._is_end_signal(signal):
-                    logger.info("Received end signal!")
-                    break
-
-                input_tokens = self.data_pipe.recv_tensor()
-
-                roi = self.data_pipe.recv_tensor()
-                assert roi is not None, "Please provide the roi when sending "\
-                    "drop-select request"
-                roi = (roi > 0.5)
-                tokens_roi_recver = [input_tokens, roi]
-
-                matched_length = 0
-
-                # perform input tokens and roi matching
-                # FIXME: this matching is O(n), ideally it should be O(1)
-                # but this buffer size won't (and shouldn't) be too large so
-                # the fix is not urgent.
-                with self.buffer_lock:
-
-                    for _ in range(len(self.buffer)):
-
-                        temp_length = self._matches(self.buffer[0],
-                                                    tokens_roi_recver)
-                        if temp_length > 0:
-                            matched_length = temp_length
-                            break
-                        # rotate the element we just accessed to the end
-                        self.buffer.rotate(-1)
-
-                    if matched_length > 0:
-                        # need to clone the tensor
-                        # in case the tensor is freed before sending finishes
-                        matched_item = self.buffer.popleft()
-                        for tensor in matched_item:
-                            self._send_tensor_and_dec_size(tensor)
-
-                    else:
-                        # no match, just send None
-                        for _ in range(5):
-                            self.data_pipe.send_tensor(None)
+            signal = self.signal_pipe.recv_tensor(rank)
+            if self._is_end_signal(signal):
+                logger.info("Received end signal!")
+                return
+            target_kv_rank = self.data_pipe.recv_tensor(rank)
+            # assert target_rank.item() == rank, "Target rank does not match"\
+            #     "the rank of the drop-select handler"
+            input_tokens = self.data_pipe.recv_tensor(rank)
+            roi = self.data_pipe.recv_tensor(rank)
+            assert roi is not None, "Please provide the roi when sending "\
+                "drop-select request"
+            roi = (roi > 0.5)
+            tokens_roi_recver = [target_kv_rank, input_tokens, roi]
+
+            matched_length = 0
+
+            # perform input tokens and roi matching
+            # FIXME: this matching is O(n), ideally it should be O(1)
+            # but this buffer size won't (and shouldn't) be too large so
+            # the fix is not urgent.
+            with self.buffer_lock:
+
+                for _ in range(len(self.buffer)):
+
+                    temp_length = self._matches(self.buffer[0],
+                                                tokens_roi_recver)
+                    if temp_length > 0:
+                        matched_length = temp_length
+                        break
+                    # rotate the element we just accessed to the end
+                    self.buffer.rotate(-1)
+
+                if matched_length > 0:
+                    # need to clone the tensor
+                    # in case the tensor is freed before sending finishes
+                    matched_item = self.buffer.popleft()
+                    target_rank = matched_item[0].item()
+                    for tensor in matched_item[1:]:
+                        self._send_tensor_and_dec_size(tensor, rank)
+
+                else:
+                    # no match, just send None
+                    for _ in range(5):
+                        self.data_pipe.send_tensor(None, rank)
 
         except RuntimeError as e:
             if 'Connection closed by peer' not in str(e):
@@ -180,10 +188,10 @@ class SimpleBuffer(KVLookupBufferBase):
         logger.debug("Closing drop_select_handler")
 
     def drop_select(
-            self, input_tokens: Optional[torch.Tensor],
+            self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor],
             roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
 
-        assert self.request_handling_thread is None, \
+        assert not self.request_handling_thread, \
             "drop_select should be called by the KV cache consumer "\
             "(e.g. the decode vLLM instance)"
 
@@ -192,26 +200,28 @@ class SimpleBuffer(KVLookupBufferBase):
         if isinstance(roi, torch.Tensor):
             roi = roi.clone().float()
 
-        self.signal_pipe.send_tensor(self.normal_signal)
-        self.data_pipe.send_tensor(input_tokens)
-        self.data_pipe.send_tensor(roi)
+        self.signal_pipe.send_tensor(self.normal_signal, rank)
+
+        self.data_pipe.send_tensor(torch.tensor(kv_rank), rank)
+        self.data_pipe.send_tensor(input_tokens, rank)
+        self.data_pipe.send_tensor(roi, rank)
 
-        input_tokens = self.data_pipe.recv_tensor()
-        roi = self.data_pipe.recv_tensor()
+        input_tokens = self.data_pipe.recv_tensor(rank)
+        roi = self.data_pipe.recv_tensor(rank)
         if roi is not None:
             # convert from float tensor to bool tensor
             # as PyNccl does not support sending bool tensor
             roi = (roi > 0.5)
-        key = self.data_pipe.recv_tensor()
-        value = self.data_pipe.recv_tensor()
-        hidden = self.data_pipe.recv_tensor()
+        key = self.data_pipe.recv_tensor(rank)
+        value = self.data_pipe.recv_tensor(rank)
+        hidden = self.data_pipe.recv_tensor(rank)
 
         return [input_tokens, roi, key, value, hidden]
 
     def full_handler(self):
         time.sleep(0.001)
 
-    def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+    def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
                key: torch.Tensor, value: torch.Tensor,
                hidden: torch.Tensor) -> None:
 
@@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase):
         while self.buffer_size > self.buffer_size_threshold:
             self.full_handler()
 
-        self._add_to_buffer(input_tokens, roi, key, value, hidden)
+        self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden)
 
         # when calling the insert, the current process is a sender
         # need to launch the request handler and start listening to request.
+        target_rank_global = target_rank + kv_group_rank
         if self.request_handling_thread is None:
-            self.request_handling_thread = threading.Thread(
-                target=self.drop_select_handler)
-            self.request_handling_thread.start()
+            self.request_handling_thread = ThreadPoolExecutor(max_workers=1)
+        self.request_handling_thread.submit(self.drop_select_handler, target_rank_global)
 
     def close(self):
 
-        if hasattr(self, "request_handling_thread"
-                   ) and self.request_handling_thread is not None:
-            self.request_handling_thread.join()
+        if hasattr(self, "request_handling_thread") and self.request_handling_thread:
+            self.request_handling_thread.shutdown()
 
         else:
             # TODO: have a explicit close signal and have a explicit way to
diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py
index 40589fb3..da2829cf 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/base.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/base.py
@@ -23,7 +23,7 @@ class KVPipeBase(ABC):
     """
 
     @abstractmethod
-    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
+    def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None:
         """Send a tensor, or None, via the pipe.
         
         Need to support sending None -- important for error handling.
@@ -41,7 +41,7 @@ class KVPipeBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def recv_tensor(self) -> Optional[torch.Tensor]:
+    def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
         """Receive a tensor (can be None) from the pipeline.
 
         Returns:
diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
index 7aa53d07..db10f8a0 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
@@ -45,33 +45,33 @@ class PyNcclPipe(KVPipeBase):
     METADATA_DTYPE = torch.int64
 
     def __init__(self,
+                 kv_group_rank: int,
                  local_rank: int,
                  config: KVTransferConfig,
                  device: Optional[str] = None,
                  port_offset: int = 0):
         self.config = config
         self.local_rank = local_rank
-        self.kv_rank = self.config.kv_rank
+        self.kv_group_rank = kv_group_rank
         self.kv_parallel_size = self.config.kv_parallel_size
+        self.kv_world_size = self.config.kv_world_size
         if device is None:
             self.device = self._select_device(self.config.kv_buffer_device)
         else:
             self.device = self._select_device(device)
 
         # build distributed connection and send/recv implementation
+        logger.info("Creating process group for kv transfer with rank %d and world size %d, ip: %s, port: %d", self.kv_group_rank, self.kv_world_size, self.config.kv_ip, self.config.kv_port + port_offset)
         self.group = StatelessProcessGroup.create(
             host=self.config.kv_ip,
             port=self.config.kv_port + port_offset,
-            rank=self.kv_rank,
-            world_size=self.kv_parallel_size,
+            rank=self.kv_group_rank,
+            world_size=self.kv_world_size,
         )
         # add a barrier to make sure the connection is initiated properly
         self.group.barrier()
         impl = self._get_device_send_recv_impl(self.group)
         self.device_send_func, self.device_recv_func = impl
-        # set target rank
-        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
-        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
 
         # transportation-related variables
         self.transport_thread: Optional[ThreadPoolExecutor] = None
@@ -145,16 +145,16 @@ class PyNcclPipe(KVPipeBase):
                            dtype=metadata["dtype"],
                            device=self.device)
 
-    def _send_metadata(self, metadata: Metadata):
+    def _send_metadata(self, metadata: Metadata, target_rank: int):
         """
         Send the metadata dictionary to the target rank.
 
         Parameters:
             - metadata: A dictionary with keys "dtype" and "shape".
         """
-        self.group.send_obj(metadata, self.target_rank_for_send)
+        self.group.send_obj(metadata, target_rank)
 
-    def _recv_metadata(self) -> Metadata:
+    def _recv_metadata(self, src_rank: int) -> Metadata:
         """
         Receive the metadata dictionary from the target rank.
 
@@ -162,9 +162,9 @@ class PyNcclPipe(KVPipeBase):
             - metadata: A dictionary with keys "dtype" and "shape" describing 
               the tensor.
         """
-        return self.group.recv_obj(self.target_rank_for_recv)
+        return self.group.recv_obj(src_rank)
 
-    def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
+    def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None:
         """
         The actual implementation of sending the tensor and its metadata to the 
         target rank.
@@ -174,12 +174,12 @@ class PyNcclPipe(KVPipeBase):
               being sent.
         """
         metadata = self._make_metadata(tensor)
-        self._send_metadata(metadata)
+        self._send_metadata(metadata, target_rank)
         if tensor is not None:
             self.device_send_func(tensor.to(self.device),
-                                  self.target_rank_for_send)
+                                  target_rank)
 
-    def _recv_impl(self) -> Optional[torch.Tensor]:
+    def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]:
         """
         The actual implementation of receiving a tensor and its metadata from 
         the target rank.
@@ -187,21 +187,22 @@ class PyNcclPipe(KVPipeBase):
         Returns:
             - buffer: The received tensor, or None if no tensor is received.
         """
-        metadata = self._recv_metadata()
+        metadata = self._recv_metadata(src_rank)
         if metadata["dtype"] is None:
             return None
         buffer = self._prepare_recv_buffer(metadata)
-        self.device_recv_func(buffer, self.target_rank_for_recv)
+        self.device_recv_func(buffer, src_rank)
 
         return buffer
 
     def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
-                            tensor_size: int) -> None:
+                            tensor_size: int,
+                            target_rank: int) -> None:
         """
         Wrapper for _send_impl to handle exceptions and update buffer size.
         """
         try:
-            self._send_impl(tensor)
+            self._send_impl(tensor, target_rank)
 
             with self.buffer_size_lock:
                 self.buffer_size -= tensor_size
@@ -220,7 +221,7 @@ class PyNcclPipe(KVPipeBase):
             logger.debug("KV cache transfer pipe is full. Waiting...")
             time.sleep(0.05)
 
-    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
+    def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None:
         """
         Sends a tensor and its metadata to the destination rank in a 
         non-blocking way.
@@ -228,6 +229,7 @@ class PyNcclPipe(KVPipeBase):
         Parameters:
             - tensor: The tensor to send, or None if no tensor is being sent.
         """
+        logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank)
         if self.transport_thread is None:
             self.transport_thread = ThreadPoolExecutor(max_workers=1)
 
@@ -242,19 +244,23 @@ class PyNcclPipe(KVPipeBase):
             self.buffer_size += tensor_size
 
         self.transport_thread.submit(self.send_tensor_wrapper, tensor,
-                                     tensor_size)
+                                     tensor_size,
+                                     target_rank)
 
-    def recv_tensor(self) -> Optional[torch.Tensor]:
+    def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
         """
         Receives a tensor and its metadata from the source rank. Blocking call.
 
         Returns:
             - tensor: The received tensor, or None if no tensor is received.
         """
+
+        logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank)
+
         if self.transport_thread is None:
             self.transport_thread = ThreadPoolExecutor(max_workers=1)
 
-        future = self.transport_thread.submit(self._recv_impl)
+        future = self.transport_thread.submit(self._recv_impl, src_rank)
 
         try:
             tensor = future.result()
diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py
index 1e80e0bd..cd90206f 100644
--- a/vllm/distributed/kv_transfer/kv_transfer_agent.py
+++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py
@@ -35,6 +35,7 @@ class KVTransferAgent:
         rank: int,
         local_rank: int,
         config: "VllmConfig",
+        world_group,
     ):
 
         self.config = config
@@ -47,7 +48,7 @@ class KVTransferAgent:
             "TransferAgent should only be used when kv_connector is set."
 
         self.connector = KVConnectorFactory.create_connector(
-            rank, local_rank, config)
+            rank, local_rank, config, world_group)
 
     def send_kv_caches_and_hidden_states(
         self,
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 321902d1..b8937ef8 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -1085,7 +1085,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
         _KV_TRANSFER = kv_transfer.KVTransferAgent(
             rank=get_world_group().rank,
             local_rank=get_world_group().local_rank,
-            config=vllm_config)
+            config=vllm_config,
+            world_group=get_world_group())
 
 
 def ensure_model_parallel_initialized(
892
893
894
895
896
897
898
899
900
901
902
903
904
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 773f5abe..3eefd266 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -585,6 +585,8 @@ class DeepseekV2Model(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
 
+        self.config = config
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size